Merge branch 'master' into Fix_FileWriter

Commit 56cafc2c6c by JaehunRyu, 2020-02-28 16:15:03 +09:00
2383 changed files with 80255 additions and 49163 deletions


@ -37,7 +37,6 @@
# v2: Build TF v2
#
# Feature and Third party library support options:
# xla: Build TF with XLA
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
@ -222,6 +221,19 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Enable XLA
build --action_env=TF_ENABLE_XLA=1
build --define=with_xla_support=true
# Keep config XLA until all build scripts are cleaned up.
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
@ -307,29 +319,29 @@ build:v2 --action_env=TF2_BEHAVIOR=1
build --config=v2
test --config=v2
# Enable XLA
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WON'T WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --auth_enabled=true
build:rbe --auth_scope=https://www.googleapis.com/auth/cloud-source-tools
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_best_effort=false
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
build:rbe --distinct_host_configuration=false
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
@ -339,7 +351,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
# Non-rbe settings we should include because we do not run configure
build:rbe_linux --config=xla
build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
@ -354,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010-gpu"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
@ -377,9 +389,8 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
@ -396,9 +407,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe


@ -1 +1 @@
1.2.1
2.0.0

.github/ISSUE_TEMPLATE/00-bug-issue.md (new file)

@ -0,0 +1,44 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'
---
<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.


@ -1,35 +0,0 @@
---
name: Bug/Performance Issue
about: Use this template for reporting a bug or a performance issue.
---
<em>Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
**Other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.


@ -1,6 +1,7 @@
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'
---


@ -1,10 +1,11 @@
---
name: Documentation Issue
about: Use this template for documentation related
about: Use this template for documentation related issues
labels: 'type:docs'
---
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.


@ -1,6 +1,6 @@
---
name: Feature Request
about: Use this template for raising a feature request
name: Feature Request about: Use this template for raising a feature request
labels: 'type:feature'
---


@ -1,10 +1,10 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting ops you are using or missing.
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
@ -17,8 +17,14 @@ about: Use this template for reporting ops you are using or missing.
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.


@ -1,6 +1,7 @@
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'
---


@ -1,6 +1,7 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite.
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
@ -12,6 +13,7 @@ about: Use this template for reporting issues during model conversion to TFLite.
**Command used to run the converter or code if you're using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command


@ -0,0 +1,45 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'
---
<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock
example script provided in TensorFlow):
- OS Platform and Distribution (e.g.,
Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
the issue happens on mobile device:
- TensorFlow installed from (source or
binary): - TensorFlow version (use command below):
- Python version: - Bazel
version (if compiling from source):
- GCC/Compiler version (if compiling from
source):
- CUDA/cuDNN version: - GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

.gitignore

@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/tools/make/gen/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl

.pylintrc (symbolic link)

@ -0,0 +1 @@
tensorflow/tools/ci_build/pylintrc


@ -70,7 +70,7 @@ $ python
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> hello.numpy()
'Hello, TensorFlow!'
b'Hello, TensorFlow!'
```
For more examples, see the


@ -1,13 +1,11 @@
workspace(name = "org_tensorflow")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//third_party:repo.bzl", "tf_http_archive")
tf_http_archive(
http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
patch_file = "@org_tensorflow//third_party:rules_closure.patch",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
@ -115,3 +113,28 @@ http_archive(
"https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")


@ -49,8 +49,8 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '1.2.1'
_TF_MAX_BAZEL_VERSION = '1.2.1'
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MAX_BAZEL_VERSION = '2.0.0'
NCCL_LIB_PATHS = [
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
@ -1390,10 +1390,6 @@ def main():
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
xla_enabled_by_default = is_linux() or is_macos()
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
xla_enabled_by_default, 'xla')
set_action_env_var(
environ_cp,
'TF_NEED_OPENCL_SYCL',


@ -187,6 +187,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "fuchsia",
values = {"cpu": "fuchsia"},
visibility = ["//visibility:public"],
)
config_setting(
name = "ios_x86_64",
values = {
@ -448,19 +454,66 @@ config_setting(
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
name = "mobile",
match_any = [
":android",
":chromiumos",
":emscripten",
":ios",
],
)
config_setting(
name = "lite_protos_legacy",
values = {"define": "TENSORFLOW_PROTOS=lite"},
visibility = ["//visibility:private"],
)
config_setting(
name = "full_protos",
values = {"define": "TENSORFLOW_PROTOS=full"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "lite_protos",
match_any = [":lite_protos_legacy"],
)
selects.config_setting_group(
name = "mobile_lite_protos",
match_all = [
":lite_protos",
":mobile",
],
)
selects.config_setting_group(
name = "mobile_full_protos",
match_all = [
":full_protos",
":mobile",
],
)
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
],
)
@ -494,8 +547,8 @@ cc_library(
name = "grpc",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc_unsecure"],
"//conditions:default": ["@grpc"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc"],
}),
)
@ -503,8 +556,8 @@ cc_library(
name = "grpc++",
visibility = ["//visibility:public"],
deps = select({
":linux_s390x": ["@grpc//:grpc++_unsecure"],
"//conditions:default": ["@grpc//:grpc++"],
":linux_s390x": ["@com_github_grpc_grpc//:grpc++_unsecure"],
"//conditions:default": ["@com_github_grpc_grpc//:grpc++"],
}),
)
@ -589,7 +642,7 @@ tf_cc_shared_object(
"//tensorflow/core:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler/lib:profiler_session_impl",
"//tensorflow/core/profiler:profiler_impl",
"//tensorflow/stream_executor:stream_executor_impl",
"//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(),
@ -909,7 +962,6 @@ py_library(
"//conditions:default": [":tf_python_api_gen_v1"],
}) + [
":root_init_gen",
":virtual_root_init_gen",
"//tensorflow/python/keras/api:keras_python_api_gen",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v1",
"//tensorflow/python/keras/api:keras_python_api_gen_compat_v2",


@ -35,9 +35,11 @@ import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -69,13 +71,13 @@ except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v2 import keras
@ -85,6 +87,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top


@ -22,12 +22,14 @@ import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
@ -64,13 +66,14 @@ elif _tf_api_dir not in __path__:
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v1 import keras
@ -80,6 +83,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """


@ -536,6 +536,7 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//tensorflow/core/platform:resource_loader",
],
)
@ -647,12 +648,14 @@ tf_cuda_cc_test(
deps = [
":c_api",
":kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)


@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
@ -816,12 +817,15 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
node_def.set_name(tfe_op->operation.Name());
node_def.set_op(tfe_op->operation.Name());
node_def.set_name(tfe_op->operation->Name());
node_def.set_op(tfe_op->operation->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());
tensorflow::down_cast<tensorflow::OperationInterface*>(
tfe_op->operation.get())
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =


@ -45,6 +45,8 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@ -193,8 +195,9 @@ TEST(CAPI, LibraryLoadFunctions) {
{
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op1.so", status);
string lib_path = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
@ -1350,9 +1353,9 @@ TEST_F(CApiColocationTest, ClearViaProto) {
TEST(CAPI, SavedModel) {
// Load the saved model.
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
TF_Buffer* metagraph = TF_NewBuffer();
@ -1426,9 +1429,9 @@ TEST(CAPI, SavedModel) {
}
TEST(CAPI, SavedModelNullArgsAreValid) {
const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";
const string saved_model_dir = tensorflow::io::JoinPath(
tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);
const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123"));
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Status* s = TF_NewStatus();
const char* tags[] = {tensorflow::kSavedModelTagServe};
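The hunks above switch these C API tests from paths relative to the source tree to runfiles lookups. A minimal sketch of the same pattern (not from the changed files), assuming the data file is declared in the test target's `data` attribute; the path and test name here are illustrative:
```
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"

// Sketch: resolve a bazel data dependency the way the updated tests do,
// instead of hard-coding a path relative to the source tree.
TEST(ResourceLoaderSketch, FindsTestData) {
  // Runfiles-relative path of a file listed in the test's `data` attribute.
  const std::string path = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
  EXPECT_FALSE(path.empty());
}
```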


@ -28,6 +28,8 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"operation_interface.cc",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api.h"],
@ -56,6 +58,7 @@ tf_cuda_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
@ -92,6 +95,7 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
visibility = [
@ -104,6 +108,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"operation_interface.h",
"tensor_handle_interface.h",
],
hdrs = ["c_api_internal.h"],
@ -128,6 +133,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/container:fixed_array",
],
)
@ -199,6 +205,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)
@ -256,8 +263,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,


@ -27,7 +27,6 @@ limitations under the License.
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@ -95,14 +94,6 @@ using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
const tensorflow::OpDef* op_def = op->operation.OpDef();
if (op_def) return op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
@ -883,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status) {
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->ClearRemoteExecutors();
status->status = ctx->context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM
}
@ -1074,6 +1065,10 @@ AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
return new TensorHandleInterface(handle_);
}
void tensorflow::TensorHandleInterface::EnableImplicitMirroring() {
handle_->EnableImplicitMirroring();
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
@ -1121,9 +1116,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle_->device())) {
if (IsCPU(handle_->device()) || handle_->HasLocalMirror(nullptr)) {
const tensorflow::Tensor* src = nullptr;
*status = handle_->Tensor(&src);
if (handle_->HasLocalMirror(nullptr)) {
*status = handle_->TensorFromDevice(nullptr, &src);
} else {
*status = handle_->Tensor(&src);
}
if (!status->ok()) return nullptr;
tensor = *src;
} else {
@ -1131,6 +1130,13 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
CHECK_NE(ctx, nullptr);
*status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
if (!status->ok()) return nullptr;
if (handle_->ImplicitMirroring()) {
*status = handle_->AddEmptyLocalMirror(nullptr);
if (!status->ok()) return nullptr;
Tensor mirror = tensor;
*status = handle_->SetTensor(std::move(mirror), nullptr);
if (!status->ok()) return nullptr;
}
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
@ -1195,31 +1201,23 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
!tensorflow::DataTypeCanUseMemcpy(
static_cast<tensorflow::DataType>(dtype))) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to create a tensor with a pointer to non-pod memory.");
deallocator(data, len, deallocator_arg);
return nullptr;
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
if (custom_device == nullptr) {
variant_device = device;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
} else {
variant_device = custom_device;
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, custom_device, context, &ret_handle);
}
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, variant_device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
@ -1258,9 +1256,8 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
std::unique_ptr<TFE_Op> new_op(
new TFE_Op{tensorflow::EagerOperation(ctx->context)});
status->status =
new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
new TFE_Op{std::make_unique<tensorflow::OperationInterface>(ctx)});
status->status = new_op->operation->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
new_op.reset();
}
@ -1270,49 +1267,51 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation.SetDeviceName(device_name);
status->status = op->operation->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext().HostCPU()
: op->operation.Device();
return device->name().c_str();
return op->operation->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = op->operation->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
tensorflow::TensorHandle* h =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
input->handle.get())
->Handle();
op->operation.AddInput(h);
status->status = op->operation.MaybeInferSingleInputAttrs(h);
status->status = op->operation->AddInput(input->handle);
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
inputs[i]->handle.get())
->Handle());
handles[i].reset(inputs[i]->handle->Copy());
}
status->status = op->operation.InferInputListAttrs(num_inputs);
status->status = op->operation->AddInputList(handles);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
const tensorflow::AttrTypeMap* attr_types_;
bool is_function;
status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(),
&attr_types_, &is_function);
if (!status->status.ok()) {
return ret;
}
status->status =
tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
return ret;
}
@ -1333,221 +1332,169 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::StringPiece(static_cast<const char*>(value), length));
auto s = op->operation->SetAttrString(
attr_name, static_cast<const char*>(value), length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
auto s = op->operation->SetAttrInt(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->operation.MutableAttrs()->Set(attr_name, value);
auto s = op->operation->SetAttrFloat(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
auto s = op->operation->SetAttrType(attr_name, value);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->operation.MutableAttrs()->Set(attr_name, proto);
out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims);
}
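The scalar attribute setters above now log a warning when an attribute cannot be set, while TFE_OpSetAttrShape reports through its TF_Status out-parameter. A small usage sketch (not part of the diff), assuming an existing eager context; the op (VarHandleOp) and attribute values are illustrative:
```
#include "tensorflow/c/eager/c_api.h"

// Sketch: set a type and a shape attribute on an eager op.
void SetExampleAttrs(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
  if (TF_GetCode(status) == TF_OK) {
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    const int64_t dims[] = {2, 2};
    TFE_OpSetAttrShape(op, "shape", dims, /*num_dims=*/2, status);
  }
  TFE_DeleteOp(op);
  TF_DeleteStatus(status);
}
```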
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunction(attr_name, value->operation);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
op->operation.MutableAttrs()->Set(attr_name, attr_value);
auto s = op->operation->SetAttrFunctionName(attr_name, data, length);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
status->status = op->operation->SetAttrTensor(attr_name, tensor);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
std::vector<tensorflow::StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
lengths[i]);
auto s =
op->operation->SetAttrStringList(attr_name, values, lengths, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(attr_name, v);
}
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
auto s = op->operation->SetAttrFloatList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
auto s = op->operation->SetAttrIntList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
auto s = op->operation->SetAttrTypeList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
}
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
auto s = op->operation->SetAttrBoolList(attr_name, values, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
out_status->status =
op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
auto s = op->operation->SetAttrFunctionList(attr_name, value, num_values);
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status =
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
return;
}
if (op == nullptr || op->operation == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Got a null or uninitialized `op` argument");
return;
}
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
operation->MutableAttrs()->Set(attr_name, attr_value);
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->InputLength(input_name, &ret);
return ret;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
int ret = -1;
status->status = op->operation->OutputLength(output_name, &ret);
return ret;
}
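Both length queries now simply delegate to the operation interface. A hedged usage sketch (not from the changed files), assuming an existing context; MatMul and its argument names come from the standard op registry, so both calls should return 1:
```
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: query how many tensors back a named input/output of an op.
void QueryLengths(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  if (TF_GetCode(status) == TF_OK) {
    int num_a = TFE_OpGetInputLength(matmul, "a", status);
    int num_out = TFE_OpGetOutputLength(matmul, "product", status);
    (void)num_a;
    (void)num_out;
  }
  TFE_DeleteOp(matmul);
  TF_DeleteStatus(status);
}
```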
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
VLOG(1) << "Calling TFE_Execute() on op " << op;
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
retvals[i] = new TFE_TensorHandle{std::move(handles[i])};
}
}
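TFE_Execute now receives opaque handle interfaces back and wraps them into TFE_TensorHandles. An end-to-end sketch of the call sequence (illustrative, not part of the diff), assuming `a` and `b` are existing float matrix handles of compatible shapes; error handling is abbreviated:
```
#include "tensorflow/c/eager/c_api.h"

// Sketch: run a single MatMul eagerly and resolve the result to a TF_Tensor.
TF_Tensor* MatMulOnce(TFE_Context* ctx, TFE_TensorHandle* a,
                      TFE_TensorHandle* b, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op, a, status);
  TFE_OpAddInput(op, b, status);
  if (TF_GetCode(status) != TF_OK) {
    TFE_DeleteOp(op);
    return nullptr;
  }
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);
  TFE_DeleteOp(op);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TF_Tensor* result = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  return result;
}
```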
@ -1675,6 +1622,31 @@ void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
op->operation.get());
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
}
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
TF_Status* status) {
tensorflow::NameAttrList name_and_attrs;
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
name_and_attrs.set_name(attrs->name);
status->status = MessageToBuffer(name_and_attrs, buf);
}
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
@ -1794,10 +1766,10 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(


@ -25,34 +25,20 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
status->status = op_to_reset->operation.Reset(
op_or_function_name, raw_device_name, false, nullptr);
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
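TFE_OpReset lets a caller reuse one TFE_Op allocation across ops; it now forwards to the operation interface. A minimal sketch (not from the changed files), assuming an existing context; the op names are illustrative, and passing nullptr for the device name leaves placement unspecified, matching what TFE_NewOp does internally:
```
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: reuse the same TFE_Op object for two different ops.
void ReuseOp(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
  if (TF_GetCode(status) != TF_OK) return;
  // ... add inputs and execute the Identity op ...
  TFE_OpReset(op, "Relu", /*raw_device_name=*/nullptr, status);
  // ... the same op object can now be populated and executed as Relu ...
  TFE_DeleteOp(op);
}
```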
void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle());
}
void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.
tensorflow::StartProfilerServer(port).release();
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@ -61,46 +47,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
@ -568,8 +514,7 @@ void TFE_DeleteCancellationManager(
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
op->operation.SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = op->operation->SetCancellationManager(cancellation_manager);
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
@ -611,3 +556,28 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::port::Free(data);
};
}
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
TF_Status* status) {
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}
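The new TFE_ContextGetFunctionDef copies a registered function's serialized FunctionDef into a caller-owned buffer. A sketch (illustrative, not part of the diff), assuming a function named "my_fn" was previously added to the context, e.g. via TFE_ContextAddFunctionDef:
```
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: fetch the serialized FunctionDef of a registered function.
void DumpFunctionDef(TFE_Context* ctx, TF_Status* status) {
  TF_Buffer* buf = TF_NewBuffer();
  TFE_ContextGetFunctionDef(ctx, "my_fn", buf, status);
  if (TF_GetCode(status) == TF_OK) {
    // buf->data / buf->length now hold a serialized tensorflow.FunctionDef
    // proto, which could be parsed with FunctionDef::ParseFromArray.
  }
  TF_DeleteBuffer(buf);
}
```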


@ -34,18 +34,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* raw_device_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// Start a profiler grpc server which listens to specified port. It will start
// the server on its own thread. It can be shutdown by terminating tensorflow.
// It can be used in both Eager mode and graph mode. Creating multiple profiler
// server is allowed. The service defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@ -54,29 +42,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// profiling and save the result into logdir which can be visualized by
// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set
// include_dataset_opts to false to profile longer traces. It will block the
// caller thread until receives tracing result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Send a grpc request to profiler server (service_addr) to perform on-demand
// monitoring and return the result in a string. It will block the
// caller thread until receiving the monitoring result.
// This API is designed for TensorBoard, for end user, please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
@ -417,9 +382,17 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status);
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
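TFE_ContextAsyncWait replaces TFE_ContextClearRemoteExecutors and waits on local executors as well as streaming requests to remote executors. A sketch of how an async client might use it (illustrative setup, not from the changed files):
```
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: with async execution enabled, enqueue work and then block until
// all pending nodes have completed, surfacing deferred errors via `status`.
void RunAsyncThenWait(TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, 1);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) return;
  // ... enqueue eager ops here; in async mode TFE_Execute may return before
  // the computation has finished ...
  TFE_ContextAsyncWait(ctx, status);  // combined status of all executors
  TFE_DeleteContext(ctx);
}
```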
// If the TensorHandle is copied to another device as part of an op execution,
// the copy is destroyed after the op has executed. Enabling implicit mirroring
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
TFE_TensorHandle*, TF_Status*);
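A sketch of the mirroring behavior described above (not part of the diff), assuming `h` is a handle whose tensor lives off the host; after the first resolve, the host copy should be kept as a mirror so the second resolve avoids another transfer:
```
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: opt a handle into implicit mirroring, then resolve it twice.
void ResolveTwice(TFE_TensorHandle* h, TF_Status* status) {
  TFE_TensorHandleEnableImplicitMirroring(h, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Tensor* first = TFE_TensorHandleResolve(h, status);
  TF_Tensor* second = TFE_TensorHandleResolve(h, status);
  TF_DeleteTensor(first);
  TF_DeleteTensor(second);
}
```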
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
@ -450,7 +423,42 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping.
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a struct with a reference to information about attributes of `op`.
//
// The `attrs` struct does not own any memory, and `op` must outlive it.
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
// containing the op name and a map of its attributes.
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
TF_Buffer* buf,
TF_Status* status);
// Set an op's attribute from a serialized AttrValue protocol buffer.
//
// Analogous to TF_SetAttrValueProto for building graph operations.
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);
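// Example (illustrative sketch only; `ctx` and `status` are assumed to be a
// valid context and status created by the caller). It copies the attributes
// of one op onto another and serializes them into a NameAttrList proto,
// mirroring the usage in the C API tests:
//
//   TFE_Op* src = TFE_NewOp(ctx, "VarHandleOp", status);
//   TFE_OpSetAttrType(src, "dtype", TF_INT64);
//   TFE_OpSetAttrShape(src, "shape", /*dims=*/NULL, /*num_dims=*/0, status);
//   TFE_OpAttrs attrs;
//   TFE_OpGetAttrs(src, &attrs);   // `attrs` only references `src`.
//   TFE_Op* dst = TFE_NewOp(ctx, "VarHandleOp", status);
//   TFE_OpAddAttrs(dst, &attrs);   // `dst` now also has dtype and shape set.
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_OpAttrsSerialize(&attrs, buf, status);  // NameAttrList in buf->data.
//   TF_DeleteBuffer(buf);
//   TFE_DeleteOp(dst);
//   TFE_DeleteOp(src);             // `src` must outlive `attrs`.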
#define TFE_CUSTOM_DEVICE_VERSION 1
// Struct to be filled in by the custom device implementation and passed to
// TFE_RegisterCustomDevice.
typedef struct TFE_CustomDevice {
@ -467,10 +475,10 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
// TODO(allenl) figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
const char* operation_name, const TFE_OpAttrs* attributes,
int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
@ -501,6 +509,11 @@ typedef struct TFE_CustomDevice {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
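// Example (illustrative sketch only): an execute callback that forwards the
// incoming attributes to an op run on an underlying device, in the spirit of
// the logging device in custom_device_test.cc. `MyDevice` and its `ctx` and
// `underlying_device` fields are hypothetical state supplied at registration
// time via `device_info`.
//
//   typedef struct MyDevice {
//     TFE_Context* ctx;
//     const char* underlying_device;
//   } MyDevice;
//
//   void MyExecute(int num_inputs, TFE_TensorHandle** inputs,
//                  const char* operation_name, const TFE_OpAttrs* attributes,
//                  int* num_outputs, TFE_TensorHandle** outputs, TF_Status* s,
//                  void* device_info) {
//     MyDevice* dev = (MyDevice*)device_info;
//     TFE_Op* op = TFE_NewOp(dev->ctx, operation_name, s);
//     if (TF_GetCode(s) != TF_OK) return;
//     TFE_OpAddAttrs(op, attributes);  // Forward the attributes verbatim.
//     TFE_OpSetDevice(op, dev->underlying_device, s);
//     for (int i = 0; i < num_inputs; ++i) {
//       TFE_OpAddInput(op, inputs[i], s);
//     }
//     TFE_Execute(op, outputs, num_outputs, s);
//     TFE_DeleteOp(op);
//   }
//
// A TFE_CustomDevice with `execute = &MyExecute` (plus the copy and delete
// callbacks) would then be passed to TFE_RegisterCustomDevice along with a
// heap-allocated MyDevice as `device_info`.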
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
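// Example (illustrative sketch only; "my_fn" is a hypothetical function name,
// and `ctx`/`status` are assumed to come from the caller):
//
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextGetFunctionDef(ctx, "my_fn", buf, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // buf->data / buf->length now hold the serialized
//     // tensorflow.FunctionDef.
//   }
//   TF_DeleteBuffer(buf);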
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@ -89,7 +89,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
tensorflow::EagerOperation operation;
std::unique_ptr<AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
@ -236,4 +236,17 @@ struct TFE_Executor {
tensorflow::EagerExecutor* unowned_executor;
};
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
// that sometimes do not require serialization.
struct TFE_OpAttrs {
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
const char* op_name)
: name(op_name), attributes(value) {}
const char* name;
const tensorflow::AttrBuilder* attributes;
};
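// Construction sketch (illustrative only; `builder` is a hypothetical
// tensorflow::AttrBuilder describing an op named "MyOp"):
//
//   TFE_OpAttrs attrs(&builder, "MyOp");
//   // Both pointers are borrowed; `builder` and the name string must outlive
//   // `attrs`.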
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Handles are on task0 (local), and task2, but op is on task1.
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
TFE_OpSetDevice(matmul, task1_name, status);
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
}
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(1), remote_arg);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
worker_server2.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(false, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
TestRemoteExecuteSilentCopies(true, true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(false, false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(true, false);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {

View File

@ -17,12 +17,15 @@ limitations under the License.
#include <string.h>
#include <string>
#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -365,7 +368,8 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy) {
TFE_ContextDevicePlacementPolicy thread_policy,
bool cpu_op) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -376,26 +380,49 @@ void TensorHandleSilentCopy(bool async,
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
}
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
if (cpu_op) {
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
} else {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
}
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate that the inputs were not replaced with different TensorHandles
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->GetInput(0), arg0);
EXPECT_EQ(op->GetInput(1), arg1);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
@ -411,19 +438,19 @@ void TensorHandleSilentCopy(bool async,
}
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DEVICE_PLACEMENT_SILENT, false);
}
void SetAndGetOpDevices(bool async) {
@ -559,6 +586,91 @@ TEST(CAPI, TensorHandleDevices) {
TFE_DeleteContext(ctx);
}
void ExecuteAdd(bool async, bool forward_input) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
// If a GPU exists, copy the handle to GPU so that we can exercise
// unprotecting a mirror.
std::string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* n_gpu =
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
TFE_DeleteTensorHandle(n);
n = n_gpu;
}
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
// Store pointer to raw buffer for validation of forwarding behaviour.
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
void* orig_ptr = TF_TensorData(orig);
TF_DeleteTensor(orig);
TFE_Op* add_op = AddOp(ctx, n, m);
std::string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (forward_input) {
TFE_DeleteTensorHandle(n);
}
int num_retvals = 1;
if (async) {
// Enqueue dummy ops so we backlog async execution & actually test async.
for (int i = 0; i < 10000; ++i) {
TFE_TensorHandle* dummy = nullptr;
TFE_Execute(add_op, &dummy, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(dummy);
}
}
TFE_TensorHandle* retval = nullptr;
TFE_Execute(add_op, &retval, &num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
if (!forward_input) {
TFE_DeleteTensorHandle(n);
}
TFE_DeleteOp(add_op);
TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
if (forward_input || async) {
EXPECT_EQ(orig_ptr, TF_TensorData(t));
} else {
EXPECT_NE(orig_ptr, TF_TensorData(t));
}
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retval);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float result[100 * 100] = {0};
EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < 100 * 100; ++i) {
EXPECT_EQ(2.0f, result[i]);
}
TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1197,6 +1309,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
TFE_DeleteTensorHandle(h_shares_tensor);
}
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
tensorflow::AttrValueMap attr_values;
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
->Attrs()
.FillAttrValueMap(&attr_values);
return attr_values;
}
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1213,8 +1333,7 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
TFE_OpAddInput(minOp, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
minOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1253,8 +1372,7 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
TFE_OpAddInputList(concatOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
@ -1294,8 +1412,7 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
TFE_OpAddInputList(assertOp, data, 3, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
assertOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
EXPECT_NE(attr_found, attr_values.cend());
EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
@ -1331,16 +1448,15 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInput(concatOp, dim, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CHECK(concatOp->operation.OpDef());
CHECK(concatOp->operation->OpDef());
TFE_OpAddInput(concatOp, inputs[0], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_FALSE(concatOp->operation.OpDef())
EXPECT_FALSE(concatOp->operation->OpDef())
<< "Inference context is still present";
TFE_OpAddInput(concatOp, inputs[1], status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
concatOp->operation.Attrs().FillAttrValueMap(&attr_values);
tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
EXPECT_EQ(attr_values.find("T"), attr_values.end());
EXPECT_EQ(attr_values.find("N"), attr_values.end());
@ -1427,4 +1543,88 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpGetAttrs) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
TFE_OpAddAttrs(copy_op, &attributes);
unsigned char is_list = 0;
ASSERT_EQ(TF_ATTR_TYPE,
TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(TF_ATTR_SHAPE,
TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
copy_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(copy_op);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TestTFE_OpAttrsSerialize) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
TFE_OpAttrs attributes;
TFE_OpGetAttrs(var_op, &attributes);
TF_Buffer* serialized_attr_values = TF_NewBuffer();
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::NameAttrList name_and_attrs;
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
serialized_attr_values->length));
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
ASSERT_EQ(tensorflow::DT_INT64,
name_and_attrs.attr().find("dtype")->second.type());
TF_DeleteBuffer(serialized_attr_values);
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
string serialized_dtype;
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
&serialized_dtype));
TFE_OpSetAttrValueProto(
second_var_op, "dtype",
reinterpret_cast<const void*>(serialized_dtype.c_str()),
serialized_dtype.length(), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::AttrValueMap attr_values;
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
second_var_op->operation.get());
op->Attrs().FillAttrValueMap(&attr_values);
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
TF_DeleteStatus(status);
TFE_DeleteOp(var_op);
TFE_DeleteOp(second_var_op);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -131,6 +131,21 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2() {
return th;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();

View File

@ -42,6 +42,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return an add op adding `a` and `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
@ -31,6 +32,8 @@ struct LoggingDevice {
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
};
struct LoggedTensor {
@ -81,12 +84,14 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
@ -115,6 +120,7 @@ void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
@ -122,7 +128,7 @@ void DeleteLoggingDevice(void* device_info) {
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag) {
bool* arrived_flag, bool* executed_flag) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -131,6 +137,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
@ -144,13 +151,15 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived);
RegisterLoggingDevice(context, name, &arrived, &executed);
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
ASSERT_FALSE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
@ -160,6 +169,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
TFE_DeleteTensorHandle(hcpu);
@ -167,4 +177,118 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
TFE_DeleteContext(context);
}
TEST(CUSTOM_DEVICE, ResetOperation) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed);
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string(custom_device_name));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpReset(reused_op.get(), "Identity",
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed);
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(var_value, status.get()))
->tensor;
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
} // namespace

View File

@ -0,0 +1,312 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/operation_interface.h"
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
OperationInterface::OperationInterface(TFE_Context* ctx)
: operation_(ctx->context) {}
const string& OperationInterface::DeviceName() const {
absl::variant<Device*, CustomDevice*> variant_device =
(operation_.Device() == kVariantDeviceNull)
? operation_.EagerContext().HostCPU()
: operation_.Device();
return absl::visit([](auto* d) -> const string& { return d->name(); },
variant_device);
}
Status OperationInterface::SetDeviceName(const char* name) {
return operation_.SetDeviceName(name);
}
Status OperationInterface::SetAttrString(const char* attr_name,
const char* data, size_t length) {
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
return Status::OK();
}
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
return Status::OK();
}
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
operation_.MutableAttrs()->Set(attr_name, value);
return Status::OK();
}
Status OperationInterface::SetAttrType(const char* attr_name,
TF_DataType value) {
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
return Status::OK();
}
Status OperationInterface::SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
operation_.MutableAttrs()->Set(attr_name, proto);
return Status::OK();
}
Status OperationInterface::SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
OperationInterface* value_operation =
tensorflow::down_cast<OperationInterface*>(value.get());
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
const char* data,
size_t length) {
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
Status OperationInterface::SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) {
Tensor t;
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
Status OperationInterface::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
operation_.MutableAttrs()->Set(attr_name, v);
return Status::OK();
}
Status OperationInterface::SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status OperationInterface::SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) {
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return Status::OK();
}
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
auto value_operation =
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
funcs[i].set_name(value_operation->operation_.Name());
value_operation->operation_.Attrs().FillAttrValueMap(
funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
return Status::OK();
}
const OpDef* OperationInterface::GetOpDef(Status* status) {
const tensorflow::OpDef* op_def = operation_.OpDef();
if (op_def) return op_def;
*status = OpDefForOp(Name(), &op_def);
return op_def;
}
Status OperationInterface::InputLength(const char* input_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Input '", input_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status status;
const tensorflow::OpDef* op_def = GetOpDef(&status);
if (!status.ok()) {
return status;
}
AttrValueMap attrs;
operation_.Attrs().FillAttrValueMap(&attrs);
NameRangeMap name_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
return errors::InvalidArgument("Output '", output_name, "' not found");
}
*length = iter->second.second - iter->second.first;
return Status::OK();
}
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
}
Status OperationInterface::Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) {
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
TF_RETURN_IF_ERROR(
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
for (int i = 0; i < *num_retvals; ++i) {
retvals->at(i).reset(
new tensorflow::TensorHandleInterface(handle_retvals[i]));
}
return Status::OK();
}
Status OperationInterface::SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
operation_.SetCancellationManager(
&cancellation_manager->cancellation_manager);
return Status::OK();
}
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#include <memory>
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
public:
explicit OperationInterface(TFE_Context* ctx);
~OperationInterface() override {}
void Clear() override { operation_.Clear(); }
Status Reset(const char* op, const char* raw_device_name) override {
return operation_.Reset(op, raw_device_name, false, nullptr);
}
const string& Name() const override { return operation_.Name(); }
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) override;
Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override {
return operation_.OpDef();
};
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, TF_DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
int num_values) override;
Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_

View File

@ -55,6 +55,14 @@ class AbstractTensorHandleInterface {
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
// Maintain mirror tensors for any implicit copies to local devices. This
// setting is offered on a per-tensor-handle basis to avoid potential memory
// over-utilization from holding on to mirrors as well as the original
// tensor. Note that this setting overrides the context mirroring policy:
// even if that policy is MIRRORING_NONE, this tensor will still be mirrored.
virtual void EnableImplicitMirroring() = 0;
};
namespace tensorflow {
@ -77,6 +85,8 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
AbstractTensorHandleInterface* Copy() override;
void EnableImplicitMirroring() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
TensorHandle* Handle() { return handle_; }

View File

@ -24,12 +24,16 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include <memory>
#include <vector>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
@ -64,25 +65,41 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
}
} // namespace tensorflow
namespace {
TF_Tensor* CreateTensor(TF_ManagedBuffer* buf, TF_DataType dtype,
const int64_t* dims, int num_dims, size_t len) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
}
} // namespace
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len,
tensorflow::cpu_allocator());
return TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer,
tensorflow::cpu_allocator());
TF_ManagedBuffer* buf =
new TF_ManagedBuffer(data, len, tensorflow::deallocate_buffer,
tensorflow::cpu_allocator(), /*owns_memory=*/true);
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
TF_ManagedBuffer* buf = nullptr;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(
@ -97,24 +114,17 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
// Other types have the same representation, so copy only if it is safe to
// do so.
buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len),
len, tensorflow::deallocate_buffer, nullptr);
len, tensorflow::deallocate_buffer, nullptr,
/*owns_memory=*/true);
std::memcpy(buf->data(), data, len);
// Free the original buffer.
deallocator(data, len, deallocator_arg);
} else {
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
/*owns_memory=*/false);
}
// TODO(gjn): Make the choice of interface a compile-time configuration.
tensorflow::TensorInterface ret(
Tensor(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf));
buf->Unref();
size_t elem_size = TF_DataTypeSize(dtype);
if (elem_size > 0 && len < (elem_size * ret.NumElements())) {
return nullptr;
}
return new TF_Tensor{std::make_unique<tensorflow::TensorInterface>(ret)};
return CreateTensor(buf, dtype, dims, num_dims, len);
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* t) {

View File

@ -38,11 +38,12 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
public:
TF_ManagedBuffer(void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg)
void* deallocator_arg, bool owns_memory)
: TensorBuffer(data),
len_(len),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
deallocator_arg_(deallocator_arg),
owns_memory_(owns_memory) {}
~TF_ManagedBuffer() override {
(*deallocator_)(data(), len_, deallocator_arg_);
@ -57,13 +58,13 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer {
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// Only buffers that own their memory may be mutated by input forwarding.
bool OwnsMemory() const override { return false; }
bool OwnsMemory() const override { return owns_memory_; }
private:
const size_t len_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
bool owns_memory_;
};
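// Usage sketch (illustrative only; `data`, `len`, `deallocator`, and `arg`
// are assumed to come from the caller, mirroring the TF_NewTensor and
// TF_AllocateTensor call sites):
//
//   // Buffer whose memory was allocated by TensorFlow itself: input
//   // forwarding may reuse it.
//   auto* owned = new TF_ManagedBuffer(data, len, deallocator, arg,
//                                      /*owns_memory=*/true);
//   // Buffer wrapping caller-provided memory: never forwarded.
//   auto* wrapped = new TF_ManagedBuffer(data, len, deallocator, arg,
//                                        /*owns_memory=*/false);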
namespace tensorflow {

View File

@ -68,6 +68,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:resource_loader",
],
)
@ -224,3 +225,15 @@ filegroup(
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)
exports_files(
glob([
"testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
"testdata/x_plus_y_v2_debuginfo/**",
"testdata/CyclicModule/**",
"testdata/VarsAndArithmeticObjectGraph/**",
]),
)

View File

@ -21,15 +21,22 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
string TestDataPbTxt() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two_pbtxt", "00000123");
}
string TestDataSharded() {
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
"half_plus_two", "00000123");
}
class ReaderTest : public ::testing::Test {
protected:
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
const string export_dir = GetDataDependencyFilepath("missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());

View File

@ -20,9 +20,11 @@ from __future__ import print_function as _print_function
import logging as _logging
import os as _os
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -36,20 +38,19 @@ try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
"installation.")
try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v2 import keras
@ -59,6 +60,13 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# We would like the following to work for fully enabling 2.0 in a 1.0 install:
#

View File

@ -20,8 +20,10 @@ from __future__ import print_function as _print_function
import os as _os
import sys as _sys
import six as _six
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# pylint: disable=g-bad-import-order
@ -31,13 +33,14 @@ from tensorflow.python.tools import module_util as _module_util
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from tensorflow.python.keras.api._v1 import keras
@ -47,6 +50,14 @@ try:
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
_current_module.app.flags = flags # pylint: disable=undefined-variable

View File

@ -84,6 +84,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include <algorithm>
#include <string>
#include <vector>
@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -139,23 +141,40 @@ TEST_F(ParseCppClassTest, ParseFail) {
static void CompareWithGoldenFile(
const string& tensorflow_relative_golden_file_name,
const string& expected_contents) {
const string& expected_contents, bool ignore_cr) {
// Get rid of all CR characters; we may be running under Windows.
string sanitized_expected_contents(expected_contents);
if (ignore_cr) {
sanitized_expected_contents.erase(
std::remove(sanitized_expected_contents.begin(),
sanitized_expected_contents.end(), '\r'),
sanitized_expected_contents.end());
}
// To update the golden file, flip update_golden to true and run the
// following:
// bazel test --test_strategy=local \
// third_party/tensorflow/compiler/aot:codegen_test
const bool update_golden = false;
const string golden_file_name = io::JoinPath(
testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name);
string golden_file_name;
if (update_golden) {
golden_file_name = io::JoinPath(testing::TensorFlowSrcRoot(),
tensorflow_relative_golden_file_name);
TF_EXPECT_OK(
WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
}
golden_file_name =
GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
string golden_file_contents;
TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
&golden_file_contents));
if (ignore_cr) {
golden_file_contents.erase(std::remove(golden_file_contents.begin(),
golden_file_contents.end(), '\r'),
golden_file_contents.end());
}
EXPECT_EQ(golden_file_contents, expected_contents);
}
@ -229,14 +248,18 @@ TEST(CodegenTest, Golden) {
// The other fields in metadata_result are tested as part of the generated
// header test.
CompareWithGoldenFile("compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data);
// This specific golden test checks a binary file. It can potentially run into
// issues due to ABIs not being stable, but has not so far.
// If we see any ABI issues, we should reconsider this specific test case.
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
metadata_result.object_file_data, false);
string header;
TF_ASSERT_OK(
GenerateHeader(opts, config, compile_result, metadata_result, &header));
CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header);
CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
true);
}
} // namespace
} // namespace tfcompile

View File

@ -14,6 +14,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
"//tensorflow/compiler/tests/...",
"//tensorflow/python/...",
],
)
package_group(

View File

@ -676,12 +676,10 @@ Status Encapsulator::Subgraph::AddFunctionCallNode(
Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
AttrSlice attrs = node->attrs();
attr->clear();
bool found_group_attribute = false;
for (const auto& node_attr : attrs) {
if (node_attr.first == group_attribute_) {
TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
*attr = node_attr.second.s();
found_group_attribute = true;
break;
}
}
@ -790,7 +788,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
for (auto& entry : subgraphs_) {

View File

@ -108,7 +108,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
"(LRN, LRNGrad)."
" BN: TF FusedBatchNorm* operations."
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
"You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),

View File

@ -20,6 +20,7 @@ XLA_OPS_DEPS = [
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -41,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -206,12 +208,14 @@ se::DeviceMemoryAllocator* GetAllocator(
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function)
const NameAttrList& function,
bool has_ref_vars)
: OpKernel(ctx),
constants_(constants),
resources_(resources),
function_(function),
platform_info_(PlatformInfoFromContext(ctx)) {}
platform_info_(PlatformInfoFromContext(ctx)),
has_ref_vars_(has_ref_vars) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
@ -350,8 +354,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
{
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/true, platform_info_, resources_,
constants_, /*lazy=*/false, &client, &variables, &kernel, &executable);
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
platform_info_.device_type().type_string() == DEVICE_GPU)) {
// Suggest auto jit if the failure was with GPU or CPU.
@ -384,6 +389,18 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();
@ -462,7 +479,7 @@ bool HasRefVars(OpKernelConstruction* ctx) {
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
FunctionAttr(ctx)) {}
FunctionAttr(ctx), /*has_ref_vars=*/true) {}
XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
@ -592,6 +609,18 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::ThenExecuteFunction then_execute;
if (ctx->op_device_context()) {
then_execute = [&](se::Stream* stream, std::function<void()> fn) {
Status status = ctx->op_device_context()->ThenExecute(
down_cast<Device*>(ctx->device()), stream, std::move(fn));
if (!status.ok()) {
// This should never happen.
LOG(ERROR) << "ThenExecute failed " << status;
}
};
run_options.set_then_execute_function(&then_execute);
}
Env* env = Env::Default();
auto start_time = env->NowMicros();

View File

@ -95,12 +95,15 @@ class XlaPlatformInfo {
// in the GraphDef.
// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
//
// `has_ref_vars`: whether the input computation can have reference variables.
// TODO(cheshire): instead derive this information from the input graph.
class XlaLocalLaunchBase : public OpKernel {
public:
XlaLocalLaunchBase(OpKernelConstruction* ctx,
const std::vector<int>& constants,
const std::vector<int>& resources,
const NameAttrList& function);
const NameAttrList& function, bool has_ref_vars);
XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
~XlaLocalLaunchBase() override = default;
@ -115,6 +118,8 @@ class XlaLocalLaunchBase : public OpKernel {
const NameAttrList function_;
const XlaPlatformInfo platform_info_;
bool has_ref_vars_;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph

View File

@ -963,6 +963,22 @@ absl::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) {
return absl::nullopt;
}
// Returns true iff the attribute `attr_name` is attached to either the node or
// to its callee.
static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def,
const char* attr_name) {
bool out = false;
bool attr_value;
if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) {
out |= attr_value;
}
if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) {
out |= attr_value;
}
return out;
}
Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
auto ignore_resource_ops = [&](const Node& n, bool* ignore) {
return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore);
@ -1016,16 +1032,9 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
resource_var_operation_node_id = node->id();
}
bool is_xla_compile_attr_true = false;
bool xla_compile_attr;
if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) {
is_xla_compile_attr_true |= xla_compile_attr;
}
if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) {
is_xla_compile_attr_true |= xla_compile_attr;
}
bool is_xla_compile_attr_true =
GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) ||
GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr);
DeviceSet devices;
devices.Insert(device);
@ -1874,6 +1883,8 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"EmptyTensorList",
"ExtractImagePatches",
"Igamma",
"IgammaGradA",
"RandomGammaGrad",
"Igammac",
"FFT",
"FFT2D",
@ -1996,6 +2007,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"StatelessRandomNormal",
"StatelessRandomUniform",
"StatelessRandomUniformInt",
"StatelessRandomUniformFullInt",
"StatelessTruncatedNormal",
"StatelessWhile",
"Svd",

View File

@ -20,15 +20,17 @@ limitations under the License.
namespace tensorflow {
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
return CanCreateXlaKernel(node_def);
bool XlaKernelCreator::CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const {
return CanCreateXlaKernel(props->node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, node_def, kernel);
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const {
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {

View File

@ -29,11 +29,13 @@ class XlaKernelCreator : public CustomKernelCreator {
// Given a NodeDef 'node_def' and the function library runtime 'flr', returns
// true if 'node_def' is a call to a compilable function defined in 'flr',
// with the kXlaCompileAttr set.
bool CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const override;
bool CanCreateKernel(
const FunctionLibraryRuntime& flr,
const std::shared_ptr<const NodeProperties>& props) const override;
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
Status CreateKernel(FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
std::unique_ptr<OpKernel>* kernel) const override;
};

View File

@ -30,10 +30,12 @@ limitations under the License.
namespace tensorflow {
NodeDef ToNodeDef(const string& text) {
std::shared_ptr<NodeProperties> ToNodeProperties(const string& text) {
NodeDef node_def;
DataTypeVector dummy;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
return node_def;
return std::make_shared<NodeProperties>(nullptr, std::move(node_def), dummy,
dummy);
}
// Create a FunctionDef that takes one resource and one regular param
@ -98,11 +100,11 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
(*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
NodeDef callsite =
ToNodeDef(R"pb(
auto callsite =
ToNodeProperties(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
)pb");
(*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
(*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true);
// Note: need to set attribute on the created node.
Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
@ -127,13 +129,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}
@ -143,13 +146,14 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
Init({fdef});
XlaKernelCreator xla_kernel_creator;
Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
Status status =
xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto(
name: 'XTimesY'
op: 'XTimesY'
input: 'a'
input: 'b'
)proto"),
&kernel_);
EXPECT_TRUE(errors::IsInternal(status)) << status.ToString();
}

View File

@ -104,7 +104,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));
for (int i = 0; i < const_args.size(); ++i) {
for (size_t i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}
@ -113,7 +113,7 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (int i = 0; i < arg_types.size(); ++i) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
@ -177,7 +177,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < fbody->arg_types.size(); ++i) {
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
@ -207,7 +207,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (int i = 0; i < fbody->ret_types.size(); ++i) {
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
@ -218,15 +218,17 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()), &node_def,
&fbody->fdef.signature(), flr, dev->resource_manager(), fbody->arg_types,
input_memory_types, fbody->ret_types, output_memory_types,
flr->graph_def_version(), &s);
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function);
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
} // namespace tensorflow

View File

@ -44,11 +44,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:LoopDialectRegistration",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir/test:TestTransforms",
],
@ -106,7 +104,9 @@ tf_cc_binary(
name = "tf-opt",
deps = [
":tf_mlir_opt_main",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
],
)
@ -116,8 +116,10 @@ tf_cc_binary(
srcs = ["tf_mlir_translate_main.cc"],
deps = [
":init_mlir",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
@ -129,6 +131,7 @@ tf_cc_binary(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TranslateClParser",

View File

@ -0,0 +1,3 @@
# TensorFlow MLIR
These are the docs for: https://www.tensorflow.org/mlir

View File

@ -0,0 +1,26 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
- include: /api_docs/_upper_tabs_api.yaml
# Dropdown menu
- name: Resources
path: /resources
is_default: true
menu:
- include: /resources/_menu_toc.yaml
lower_tabs:
# Subsite tabs
other:
- name: Guide
contents:
- title: Overview
path: /mlir/overview
- heading: Dialects
- title: Overview
path: /mlir/dialects
- title: TensorFlow
path: /mlir/tf_ops
- title: TensorFlow Lite
path: /mlir/tfl_ops
- include: /_upper_tabs_right.yaml

View File

@ -0,0 +1,54 @@
book_path: /mlir/_book.yaml
project_path: /mlir/_project.yaml
description: <!--no description-->
landing_page:
custom_css_path: /site-assets/css/style.css
rows:
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
items:
- description: >
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
intermediate representation (IR) that unifies the infrastructure required to execute high
performance machine learning models in TensorFlow and similar ML frameworks. This project
will include the application of HPC techniques, along with integration of
search algorithms like reinforcement learning. MLIR aims to reduce the
cost to bring up new hardware, and improve usability for existing
TensorFlow users.
- code_block: |
<pre class = "prettyprint">
// Syntactically similar to LLVM:
func @testFunction(%arg0: i32) {
%x = call @thingToCall(%arg0) : (i32) -> i32
br ^bb1
^bb1:
%y = addi %x, %x : i32
return %y : i32
}
</pre>
- classname: devsite-landing-row-cards
items:
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
youtube_id: qzljG6DKgic
buttons:
- label: Watch the video
path: https://www.youtube.com/watch?v=qzljG6DKgic
- heading: "A new intermediate representation and compiler framework"
image_path: /resources/images/tf-logo-card-16x9.png
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
buttons:
- label: Read on TensorFlow blog
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
- heading: MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/llvm/llvm-project/tree/master/mlir
buttons:
- label: View on GitHub
path: https://github.com/llvm/llvm-project/tree/master/mlir
- heading: TensorFlow MLIR on GitHub
image_path: /resources/images/github-card-16x9.png
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir

View File

@ -0,0 +1,37 @@
# MLIR dialects
## Overview
To separate different hardware and software targets, MLIR has “dialects”,
including:
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
* XLA HLO IR, which is designed to take advantage of XLA’s compilation
abilities (with output to, among other things, TPUs).
* An experimental affine dialect, which focuses on
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
and optimizations.
* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
allowing MLIR to emit GPU and CPU code through LLVM.
* TensorFlow Lite, which will translate to running code on mobile platforms.
Each dialect consists of a set of defined operations which have invariants
placed on them, like: “This is a binary operator, and the inputs and outputs
have the same types.”
## Adding to MLIR
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
Dialects can define entirely custom types, which is how MLIR can model things
like the LLVM IR type system (which has first class aggregates), domain
abstractions important for ML-optimized accelerators like quantized types, and
even the Swift or Clang type systems (which are built around Swift/Clang
declaration nodes) in the future.
If you want to connect a new low-level compiler, you would create a new dialect
and the lowerings between the TensorFlow Graph dialect and your dialect.
This smooths the path for hardware and compiler makers. You can even target
dialects at different levels in the same model; the higher-level optimizers
will respect the unfamiliar parts of the IR and wait for a lower level to handle
them.
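To make the "create a new dialect" step concrete, here is a rough C++ sketch of what a minimal dialect registration looked like with the MLIR APIs of this period; the dialect name and class are invented for illustration and are not part of this change:

#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical accelerator dialect. Real dialects in this repository declare
// their ops and types in TableGen and register them in the constructor.
class MyAcceleratorDialect : public mlir::Dialect {
 public:
  explicit MyAcceleratorDialect(mlir::MLIRContext *context)
      : mlir::Dialect("myaccel", context) {
    // addOperations<...>() / addTypes<...>() for the target would go here.
  }
};

// Static registration hook, in the style of the *_dialect_registration
// targets touched elsewhere in this change.
static mlir::DialectRegistration<MyAcceleratorDialect> myaccel_registration;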

File diff suppressed because one or more lines are too long

(image added: 148 KiB)

View File

@ -0,0 +1,36 @@
# MLIR
## Overview
MLIR, or Multi-Level Intermediate Representation, is a representation format
and library of compiler utilities that sits between the model representation
and low-level compilers/executors that generate hardware-specific code.
MLIR is, at its heart, a flexible infrastructure for modern optimizing
compilers. This means it consists of a specification for intermediate
representations (IR) and a code toolkit to perform transformations on that
representation. (In compiler parlance, as you move from higher-level
representations to lower-level representations, these transformations can be
called “lowerings.”)
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
many great ideas from it. It has a flexible type system, and allows
representing, analyzing and transforming graphs combining multiple levels of
abstraction in the same compilation unit. These abstractions include TensorFlow
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
hardware operations and types.
We expect MLIR to be of interest to many groups, including:
* Compiler researchers and implementers looking to optimize performance and
memory consumption of machine learning models
* Hardware makers looking for a way to connect their hardware to TensorFlow,
such as TPUs, portable neural hardware in phones, and other custom ASICs
* People writing language bindings that want to take advantage of optimizing
compilers and hardware acceleration.
The TensorFlow ecosystem contains a number of compilers and optimizers that
operate at multiple levels of the software and hardware stack. We expect the
gradual adoption of MLIR to simplify every aspect of this stack.
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>

View File

@ -48,10 +48,11 @@ def _run_lit_test(name, data, size, tags, driver, features):
" the driver parameter when running this test. If you require" +
" custom driver support, please file an issue to request it.")
# Disable tests on Windows for now, to enable testing the rest of XLA and MLIR.
native.py_test(
name = name,
srcs = ["@llvm-project//llvm:lit"],
tags = tags,
tags = tags + ["no_windows"],
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,

View File

@ -208,6 +208,7 @@ cc_library(
"ir/tfl_ops.h.inc",
"ir/tfl_ops_interface.cc.inc",
"ir/tfl_ops_interface.h.inc",
"runtime_verifiers.inc",
"utils/attribute_utils.cc",
],
hdrs = [
@ -231,6 +232,7 @@ cc_library(
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
@ -302,12 +304,14 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
@ -323,6 +327,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
@ -459,9 +464,9 @@ cc_library(
)
tf_native_cc_binary(
name = "operator-converter-gen",
name = "converter-gen",
srcs = [
"operator_converter_gen.cc",
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
@ -471,14 +476,18 @@ tf_native_cc_binary(
)
gentbl(
name = "operator_converter_inc",
name = "converter_inc",
tbl_outs = [
(
"", # This driver has no options.
"--gen-operator-converters",
"operator_converters.inc",
),
(
"--gen-runtime-verifiers",
"runtime_verifiers.inc",
),
],
tblgen = ":operator-converter-gen",
tblgen = ":converter-gen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
@ -561,6 +570,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -581,8 +591,6 @@ cc_library(
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:StandardDialectRegistration",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
@ -594,6 +602,7 @@ tf_cc_binary(
name = "flatbuffer_translate",
deps = [
":flatbuffer_translate_lib",
"@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
],
)
@ -643,12 +652,14 @@ tf_cc_binary(
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
@ -696,7 +707,6 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Transforms",
],
)
@ -730,7 +740,6 @@ cc_library(
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:QuantOpsDialectRegistration",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],

View File

@ -35,7 +35,8 @@ struct PassConfig {
skip_control_dialect(false),
form_clusters(false),
inline_functions(true),
unfold_batch_matmul(true) {}
unfold_batch_matmul(true),
legalize_tf_while(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops.
@ -61,6 +62,10 @@ struct PassConfig {
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
bool unfold_batch_matmul;
// Whether to legalize TF While to TFL While.
// Note: This is a staging step and will be removed.
// TODO(b/137395003): Remove post switching legalization.
bool legalize_tf_while;
};
} // namespace TFL

View File

@ -28,6 +28,9 @@ limitations under the License.
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Attribute.h" // TF:llvm-project
#include "mlir/TableGen/Format.h" // TF:llvm-project
#include "mlir/TableGen/Operator.h" // TF:llvm-project
#include "mlir/TableGen/Predicate.h" // TF:llvm-project
using llvm::DefInit;
using llvm::dyn_cast;
@ -41,6 +44,19 @@ using llvm::SmallVector;
using llvm::StringInit;
using llvm::StringRef;
enum ActionType {
OpConv,
RuntimeVerify,
};
// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(OpConv, "gen-operator-converters",
"Generate operator converters"),
clEnumValN(RuntimeVerify, "gen-runtime-verifiers",
"Generate TFLite runtime verifiers")));
// Returns the associated option name for the given op definition.
static inline std::string GetOperatorOptionName(const Record &def) {
assert(def.getName().startswith("TFL_") && "unexpected op prefix");
@ -342,8 +358,101 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) {
return false;
}
static void GenOperandResultVerifier(raw_ostream &os,
llvm::ArrayRef<llvm::Init *> values,
StringRef valueKind) {
mlir::tblgen::FmtContext fctx;
bool first = true;
for (auto static_value : llvm::enumerate(values)) {
auto *definit = llvm::cast<llvm::DefInit>(static_value.value());
auto *val = definit->getDef()->getValue("tflRuntimeTypePredicate");
if (!val) continue;
// Create code block on first type to verify.
if (first) {
os << " {\n";
os << " unsigned index = " << static_value.index() << ";\n";
first = false;
}
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
auto desc =
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
// Emit a loop to check all the dynamic values in the pack.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
static_value.index());
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
}
// Emit closing brace if needed.
if (!first) os << " }\n";
}
// NOLINTNEXTLINE
static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os);
// Retrieve all the definitions derived from TFL_Op and sort by record name.
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
// Iterate through all the ops defined.
for (const auto *def : defs) {
mlir::tblgen::Operator op(*def);
if (!op.getTrait("TflRuntimeVerifyOpInterface::Trait")) continue;
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
// Stop at the first variadic operand for now; otherwise the getOperand index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
auto &value = op.getResult(i);
// Stop at the first variadic result for now; otherwise the getResult index
// used below doesn't match.
if (value.isVariadic()) break;
if (!value.name.empty())
verify_ctx.addSubst(value.name, formatv("op->getResult({0})", i));
}
GenOperandResultVerifier(os, def->getValueAsDag("arguments")->getArgs(),
"operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result");
os << " return mlir::success();\n}\n";
}
return false;
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
return TableGenMain(argv[0], &OperatorWritersMain);
if (action == ActionType::OpConv)
return TableGenMain(argv[0], &OperatorWritersMain);
return TableGenMain(argv[0], &RuntimeVerifierWriterMain);
}
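For reference, the --gen-runtime-verifiers backend above emits one VerifyTflRuntimeTypes definition per op into runtime_verifiers.inc. For a hypothetical op the output looks roughly like the following sketch; the op name and the type predicate are placeholders, and the exact text is determined by the tblgen code above:

// Sketch of generated output for an op class named AddOp (illustrative only).
::mlir::LogicalResult AddOp::VerifyTflRuntimeTypes(::mlir::Operation *op) {
  auto top = cast<AddOp>(op); (void)top;
  {
    unsigned index = 0;
    for (Value v : top.getODSOperands(0)) {
      (void)v;
      if (!(/* tflRuntimeTypePredicate applied to v.getType() */ true)) {
        return op->emitOpError("operand #")
               << index << " must be <runtime type description>, but got "
               << v.getType();
      }
      ++index;
    }
  }
  return mlir::success();
}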

View File

@ -46,7 +46,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
@ -76,6 +76,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
llvm::cl::init(false));
// NOLINTNEXTLINE
static opt<std::string> input_arrays_flag(
"input-arrays",
llvm::cl::desc(
"List of input tensors, if different from the default inputs"),
llvm::cl::init(""));
// NOLINTNEXTLINE
static opt<std::string> output_arrays_flag(
"output-arrays",
llvm::cl::desc(
"List of output tensors, if different from the default outputs"),
llvm::cl::init(""));
namespace {
bool IsScalar(const TensorT& tensor) {
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
@ -590,6 +605,11 @@ StatusOr<Operation*> ConvertOp(
op_state.addTypes({type});
}
if (op_name == "tfl.lstm") {
// TODO(b/147587779): add the right region if region is empty.
op_state.addRegion();
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
@ -610,43 +630,30 @@ StatusOr<Operation*> ConvertOp(
return builder.createOperation(op_state);
}
// Returns the output tensor indices for the given subgraph. If
// ordered_output_arrays is provided, then return the tensor indices in
// ordered_output_arrays.
StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
const tflite::SubGraphT& subgraph, Location base_loc,
const std::vector<std::string>& ordered_output_arrays) {
if (ordered_output_arrays.empty()) {
return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
subgraph.outputs.end());
// Returns indices of the given tensors in the subgraph. Returns error if a
// tensor name cannot be found in the subgraph.
StatusOr<std::vector<int>> GetTensorIndices(
const tflite::SubGraphT& subgraph,
const std::vector<std::string>& tensor_names) {
absl::flat_hash_map<std::string, int> name_to_index;
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
}
llvm::SmallVector<int32_t, 4> outputs;
outputs.resize(ordered_output_arrays.size());
absl::flat_hash_map<std::string, int> output_order_map;
for (auto output : llvm::enumerate(ordered_output_arrays)) {
output_order_map[output.value()] = output.index();
}
std::vector<int> indices;
indices.reserve(tensor_names.size());
int tensor_index = 0;
int found_output_tensors = 0;
for (const auto& tensor : subgraph.tensors) {
auto found = output_order_map.find(tensor->name);
if (found != output_order_map.end()) {
const int output_index = found->second;
outputs[output_index] = tensor_index;
++found_output_tensors;
for (const auto& name : tensor_names) {
auto found = name_to_index.find(name);
if (found != name_to_index.end()) {
indices.push_back(found->second);
} else {
return errors::InvalidArgument("could not find tensor in subgraph: ",
name);
}
++tensor_index;
}
if (found_output_tensors != ordered_output_arrays.size()) {
auto err = errors::InvalidArgument(
"cannot find all nodes in ordered_output_arrays");
return emitError(base_loc, err.ToString()), err;
}
return outputs;
return indices;
}
// Given a list of tensor indices, returns a string of concatenated tensor names
@ -661,15 +668,18 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
}
// Given a list of output indices, traverses the subgraph and returns the set of
// ops that are ancestors of the output tensors.
// Traverses the subgraph from output_indices to input_indices and returns the
// set of ops that are visited.
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
ArrayRef<int32_t> output_indices) {
// Create a map from tensor index to defining op.
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
for (const auto& op : subgraph.operators) {
for (int32_t output : op->outputs) {
defining_op[output] = op.get();
if (!llvm::is_contained(input_indices, output)) {
defining_op[output] = op.get();
}
}
}
@ -718,18 +728,40 @@ StatusOr<FuncOp> ConvertSubgraph(
const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
Location base_loc, Builder builder,
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
Location base_loc, Builder builder, bool is_entry_point,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
llvm::SmallVector<mlir::Type, 2> ret_types;
llvm::SmallVector<mlir::Type, 4> input_types;
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
// Construct function type
for (auto input : subgraph.inputs) {
auto& tensor = *subgraph.tensors.at(input);
std::vector<int> func_inputs = subgraph.inputs;
if (is_entry_point && !ordered_input_arrays.empty()) {
if (!experimental_prune_unreachable_nodes_unconditionally) {
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
return errors::InvalidArgument(
"input-arrays should be used with experimental pruning flag");
}
TF_ASSIGN_OR_RETURN(func_inputs,
GetTensorIndices(subgraph, ordered_input_arrays));
}
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
for (int i = 0; i < subgraph.tensors.size(); i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
input_index_set.insert(i);
}
}
for (auto input_or_variable : func_inputs) {
auto& tensor = *subgraph.tensors.at(input_or_variable);
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
// but we cannot differentiate scalars from unranked tensors.
// Here we reverse the default assumption that shape = [] means unranked.
@ -753,9 +785,11 @@ StatusOr<FuncOp> ConvertSubgraph(
}
}
TF_ASSIGN_OR_RETURN(
auto func_outputs,
GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
std::vector<int> func_outputs = subgraph.outputs;
if (is_entry_point && !ordered_output_arrays.empty()) {
TF_ASSIGN_OR_RETURN(func_outputs,
GetTensorIndices(subgraph, ordered_output_arrays));
}
for (auto output : func_outputs) {
bool is_constant = !is_op_output[output];
@ -782,8 +816,8 @@ StatusOr<FuncOp> ConvertSubgraph(
Value maybe_optional_arg_marker = nullptr;
// Get or construct MLIR values for each input
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
auto input_tensor = subgraph.inputs[i];
for (int i = 0, e = func_inputs.size(); i < e; i++) {
auto input_tensor = func_inputs[i];
const auto& tensor = *subgraph.tensors.at(input_tensor);
auto loc = TensorLoc(tensor, builder, base_loc);
if (vals_map[input_tensor]) {
@ -806,9 +840,9 @@ StatusOr<FuncOp> ConvertSubgraph(
// Set tf.entry_function attribute
if (is_entry_point) {
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
if (!subgraph.inputs.empty()) {
if (!func_inputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
subgraph, &builder, "inputs", subgraph.inputs));
subgraph, &builder, "inputs", func_inputs));
}
if (!func_outputs.empty()) {
attributes.push_back(BuildTFEntryFunctionAttribute(
@ -820,7 +854,7 @@ StatusOr<FuncOp> ConvertSubgraph(
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
if (experimental_prune_unreachable_nodes_unconditionally) {
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
PruneSubgraph(subgraph, func_outputs));
PruneSubgraph(subgraph, func_inputs, func_outputs));
}
// Construct MLIR operators from TFLite operators
@ -931,8 +965,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
OwningModuleRef tflite::FlatBufferToMlir(
absl::string_view buffer, MLIRContext* context, Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant,
const std::vector<std::string>& ordered_input_arrays,
const std::vector<std::string>& ordered_output_arrays,
bool experimental_prune_unreachable_nodes_unconditionally) {
auto model_ptr =
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
@ -971,33 +1006,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
builder.getStringAttr(model->description));
}
if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
// TODO(b/141485522): support more than one subgraph.
return emitError(base_loc,
"ordered_output_arrays does not support more than one "
"subgraph yet"),
nullptr;
}
for (auto e : llvm::enumerate(model->subgraphs)) {
auto& subgraph = e.value();
std::string name = SubgraphName(e.index(), *subgraph);
auto func_or_error = ConvertSubgraph(
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
// Only the entry point needs pseudo_input_ops
builder,
// TODO(b/131175224,b/132239787) Support multiple entry points
builder, ordered_output_arrays,
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant,
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
ordered_output_arrays,
experimental_prune_unreachable_nodes_unconditionally);
if (!func_or_error.ok()) {
return emitError(base_loc, "could not translate function ")
<< subgraph->name,
<< subgraph->name << ": "
<< func_or_error.status().error_message(),
nullptr;
}
module.push_back(func_or_error.ConsumeValueOrDie());
}
// TFLite subgraphs do not necessarily have names,
return OwningModuleRef(module);
}
@ -1012,17 +1039,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
auto loc =
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
// Parses output_arrays_order from command line option.
// Parses input/output names from command line options.
std::vector<std::string> inputs;
std::vector<std::string> outputs;
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
// Use output parser since we only have tensor names.
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
return emitError(loc, "parsing input array info failed ")
<< input_arrays_flag,
nullptr;
}
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
return emitError(loc, "parsing output array info failed ")
<< output_arrays_string,
<< output_arrays_flag,
nullptr;
}
return tflite::FlatBufferToMlir(
absl::string_view(input->getBufferStart(), input->getBufferSize()),
context, loc, outputs, use_external_constant,
context, loc, use_external_constant, inputs, outputs,
experimental_prune_unreachable_nodes_unconditionally);
}

View File

@ -35,9 +35,9 @@ namespace tflite {
// are not ancestors of the output nodes will be pruned.
mlir::OwningModuleRef FlatBufferToMlir(
absl::string_view buffer, mlir::MLIRContext* context,
mlir::Location base_loc,
const std::vector<std::string>& ordered_output_arrays,
bool use_external_constant = false,
mlir::Location base_loc, bool use_external_constant = false,
const std::vector<std::string>& ordered_input_arrays = {},
const std::vector<std::string>& ordered_output_arrays = {},
bool experimental_prune_unreachable_nodes_unconditionally = false);
} // namespace tflite

View File

@ -42,7 +42,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
@ -122,8 +122,6 @@ bool emit_custom_ops;
bool emit_select_tf_ops;
bool lower_tensor_list_ops;
bool strip_debug_info;
// NOLINTNEXTLINE
std::string output_arrays_string;
// NOLINTNEXTLINE
static opt<bool, true> emit_builtin_tflite_ops_flag(
@ -156,11 +154,6 @@ static opt<bool, true> strip_debug_info_flag(
"strip-debug-info", llvm::cl::desc("Strip debug info during export"),
llvm::cl::location(strip_debug_info), llvm::cl::init(false));
// NOLINTNEXTLINE
static opt<std::string, true> output_arrays_flag(
"output-arrays", llvm::cl::desc("List of output tensors"),
llvm::cl::location(output_arrays_string), llvm::cl::init(""));
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
// Use initial buffer size in flatbuffer builder to be same as the initial size
@ -172,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
// `isSigned` is set to false for other types.
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
bool is_signed = true) {
if (!is_signed && type.isInteger(8)) {
if (!is_signed && type.isSignlessInteger(8)) {
return tflite::TensorType_UINT8;
}
if (!is_signed) {

View File

@ -27,7 +27,5 @@ extern bool emit_custom_ops;
extern bool lower_tensor_list_ops;
// The flag to control whether debug info gets stripped on export.
extern bool strip_debug_info;
// The flag to control the output array info of tflite graph.
extern std::string output_arrays_string;
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_

View File

@ -71,4 +71,23 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
];
}
//===----------------------------------------------------------------------===//
// TFL runtime type verification of operand/result types.
def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let description = [{
Interface for TFLite runtime op verification.
This verifies that the converted TFLite ops have operand/result types
supported by the TFLite runtime.
}];
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
>,
];
}
#endif // TFL_OP_INTERFACES
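For context on how this interface is consumed: the runtime_type_verify pass added to the BUILD targets earlier in this change walks ops implementing the interface and calls the generated verifier. A hedged sketch of that walk, inside a function pass; the actual transforms/runtime_type_verify.cc may differ in detail:

// Roughly, inside the pass's runOnFunction():
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
  if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
    signalPassFailure();
});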

View File

@ -23,9 +23,10 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
@ -36,6 +37,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
@ -273,7 +275,7 @@ Attribute ConstFoldBinaryOp(
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
float_calculate, is_commutative);
if (elemType.isa<IntegerType>())
if (elemType.isSignlessInteger())
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
int_calculate, is_commutative);
@ -721,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
}
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value operand : op.getOperands()) {
auto other_type = operand.getType().cast<ShapedType>();
if (input_type != other_type)
// TODO(b/135032063): Simplify once fixed.
for (Type operand_type : op.getOperandTypes()) {
if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
return op.emitOpError("operands should be of the same type. got ")
<< input_type << ", " << other_type;
<< input_type << ", " << operand_type;
}
return success();
@ -1106,10 +1107,10 @@ static LogicalResult VerifySplitOpOutputTypes(
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
Value output = op->getResult(i);
auto output_type = output.getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
if (failed(verifyCompatibleShape(output.getType(), expected_output_type)))
return op->emitOpError()
<< "output #" << i << " should be " << expected_output_type;
<< "output #" << i << " should be " << expected_output_type
<< " instead got " << output.getType();
}
return success();
}
@ -1559,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
limit_tensor.getType().getRank() == 0 &&
delta_tensor.getType().getRank() == 0);
Type elem_type = getType().cast<ShapedType>().getElementType();
if (elem_type.isa<IntegerType>()) {
if (elem_type.isSignlessInteger()) {
auto start_attr = start_tensor.getValue<IntegerAttr>({});
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
@ -1661,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
// Do not try to fold elements attr of a quant type because
// DenseElementsAttr does not support it.
if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
return nullptr;
assert(perm_tensor.getType().getRank() == 1);
@ -1741,47 +1742,108 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
LogicalResult Verify(WhileOp op) {
if (op.getNumOperands() != op.getNumResults())
return op.emitOpError(llvm::formatv(
"number of operands does not match number of results ({0} != {1})",
op.getNumOperands(), op.getNumResults()));
// TODO(jpienaar): Verify operand, result & block arguments types
return success();
}
namespace {
struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
// Canonicalize While op so that results and operands match and external values
// are via implicit capture rather than via block args.
struct WhileResultOperandsMatchAndImplicitCapture
: public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
auto size = while_op.body().front().getArguments().size();
Operation *op = while_op.getOperation();
auto old_size = op->getNumResults();
// No change needed as the number of operands match the number of results.
if (size == old_size) return matchFailure();
// Replace values that are simply passed through the body with the externally
// captured values. The block arguments of the body match the while operands
// one-to-one, so the corresponding cond argument can easily be found.
bool unchanged = true;
auto &body_block = while_op.body().front();
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
for (auto ba : body_block.getArguments()) {
if (ba == yield.getOperand(ba.getArgNumber())) {
unchanged = false;
auto value = while_op.getOperand(ba.getArgNumber());
ba.replaceAllUsesWith(value);
cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
}
}
// Collect the new types by combining results of old op with additional
// operand results.
// The While op's operand and result types need to match.
SmallVector<Value, 4> new_operands;
SmallVector<Value, 4> new_body_yield;
SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
llvm::SmallVector<Type, 4> types;
types.reserve(size);
for (auto type : while_op.getResultTypes()) types.push_back(type);
for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
types.push_back(arg.getType());
// Collect operands.
llvm::SmallVector<Value, 8> operands;
operands.reserve(while_op.getNumOperands());
for (auto operand : while_op.getOperands()) operands.push_back(operand);
new_operands.reserve(while_op.getNumOperands());
new_body_yield.reserve(while_op.getNumOperands());
types.reserve(while_op.getNumOperands());
// Remove block arguments that are not used in either cond or body. This keeps
// the block arguments of body and cond matching.
int arg_index = 0;
for (int while_index = 0, e = while_op.getNumOperands(); while_index < e;
++while_index) {
auto value = while_op.getOperand(while_index);
if (body_block.getArgument(arg_index).use_empty() &&
cond_block.getArgument(arg_index).use_empty() &&
// This could be relaxed and casts inserted.
while_op.getResult(while_index).getType() == value.getType()) {
unchanged = false;
body_block.eraseArgument(arg_index);
cond_block.eraseArgument(arg_index);
// Mark operand as constant and replace all uses with input to while.
while_op.getResult(while_index).replaceAllUsesWith(value);
const_operand[while_index] = true;
} else {
new_operands.push_back(value);
new_body_yield.push_back(yield.getOperand(while_index));
auto type = while_op.getResult(while_index).getType();
types.push_back(type);
++arg_index;
}
}
// Done if no values removed from blocks and operands & results match.
if (unchanged) return matchFailure();
// Replace with new While with matching operands and results.
Operation *op = while_op.getOperation();
Operation *new_op = rewriter.insert(
Operation::create(op->getLoc(), op->getName(), types, operands,
Operation::create(op->getLoc(), op->getName(), types, new_operands,
op->getAttrs(), {}, /*numRegions=*/2,
/*resizableOperandList=*/true));
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
rewriter.replaceOp(op,
new_op->getResults().take_front(op->getNumResults()));
int new_index = 0;
for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
if (const_operand[op_index]) continue;
op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
++new_index;
}
rewriter.eraseOp(op);
Block &new_body_block = cast<WhileOp>(new_op).body().front();
rewriter.setInsertionPointToEnd(&new_body_block);
rewriter.replaceOpWithNewOp<YieldOp>(new_body_block.getTerminator(),
new_body_yield);
return matchSuccess();
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileResultOperandsMatch>(context);
results.insert<WhileResultOperandsMatchAndImplicitCapture>(context);
}
Region &WhileOp::getLoopBody() { return body(); }
@ -1809,6 +1871,7 @@ LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc"
#include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc"
Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder,
Attribute value,

File diff suppressed because it is too large

View File

@ -17,6 +17,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/tensorflow",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
@ -62,6 +63,41 @@ const char kDetectionPostProcessOp[] =
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
const char kUnidirectionalSequenceLstmOp[] =
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
"name: 'InputCellStateTensor' type: DT_FLOAT } "
"output_arg: { name: 'Concat' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
const char kUnidirectionalSequenceRnnOp[] =
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
"name: 'Bias' type: DT_FLOAT} "
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
"output_arg: { name: "
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
"DT_FLOAT} "
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
// conversion mapping for constants defined in TFLite Python API.
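// As an illustration (a sketch of the expected behavior, not text from this
// patch): toco::IODataType::FLOAT is mapped to tensorflow::DT_FLOAT and
// toco::IODataType::QUANTIZED_UINT8 to tensorflow::DT_QUINT8.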
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
@ -259,6 +295,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN(
@ -277,6 +315,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.lower_tensor_list_ops = true;
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
}
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project

View File

@ -61,11 +61,9 @@ TfLiteStatus QuantizeModel(
std::string serialized_model(
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
input_builder.GetSize());
std::vector<std::string> output_arrays_order;
OwningModuleRef module =
tflite::FlatBufferToMlir(serialized_model, &context,
UnknownLoc::get(&context), output_arrays_order);
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
UnknownLoc::get(&context));
if (!module) {
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
return kTfLiteError;

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
@ -150,7 +150,8 @@ struct QuantizationPattern : public RewritePattern {
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
float error_tolerance, bool single_layer_verify)
: RewritePattern(DQ::getOperationName(), 1, context),
// Set the benefit to a large number so this pattern is always preferred.
: RewritePattern(DQ::getOperationName(), 300, context),
enable_verify(enable_verify),
error_tolerance(error_tolerance),
single_layer_verify(single_layer_verify) {}
@ -190,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
auto ele_type = operand.getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
inputs.push_back(op_inst.input());
} else if (ele_type.isa<IntegerType>()) {
} else if (ele_type.isSignlessInteger()) {
// If the operand is an integer tensor, then it doesn't require the
// DQ op in the pattern.
inputs.push_back(operand);
@ -224,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
auto user = llvm::cast<Q>(*result.user_begin());
outputs_replaced.insert({user.output(), enumerated_result.index()});
output_types.push_back(user.getType());
} else if (result_ele_type.template isa<IntegerType>()) {
} else if (result_ele_type.isSignlessInteger()) {
// If the result is an integer tensor, then it doesn't require the
// D op in the pattern.
outputs_replaced.insert({result, enumerated_result.index()});

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project

View File

@ -48,11 +48,9 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
std::string serialized_model(
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
input_builder.GetSize());
std::vector<std::string> output_arrays_order;
OwningModuleRef module =
tflite::FlatBufferToMlir(serialized_model, &context,
UnknownLoc::get(&context), output_arrays_order);
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
UnknownLoc::get(&context));
if (!module) {
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
return kTfLiteError;

View File

@ -15,5 +15,6 @@ filegroup(
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -test-constant-fold | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {

View File

@ -10,9 +10,7 @@ glob_lit_tests(
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = [
"pbtxt",
# TODO(fengliuai): reenable these tests after the fused loc is
# supported in the diagnostic handler.
# "py",
# "py", TODO(b/150304798)
],
)

View File

@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
return %2 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
// CHECK-NEXT: return [[RESULT]]
}
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
return %4 : tensor<1x128x128x1xf32>
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
// CHECK-NEXT: return [[RESULT]]
}

View File

@ -178,15 +178,20 @@ func @inputsAfterOutputs() {
// -----
// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}}
module {
func @extractOphintFailure() {
func @extractOphintSame() {
%0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
%1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
return
// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
}
func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> {

View File

@ -0,0 +1,13 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Tests -input-arrays flag.
func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
return %2 : tensor<4xf32>
// CHECK-LABEL: main
// CHECK-NOT: tfl.squared_difference
// CHECK: tfl.mul %[[CONST:.*]], %arg0
}

View File

@ -0,0 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Ensure lstm roundtrips exactly
func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
// CHECK-LABEL: main
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %[[RES0]]
}

View File

@ -1,25 +1,31 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Check to see if function references in while loops are preserved
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// TODO(b/138222071) Expect first output to be a scalar
// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// While %arg0 is greater than zero, element-wise add %arg1 with itself.
%0:2 = "tf.While"(%arg0, %arg1) {
cond = @cond, body = @body, is_stateless = false
} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1 = call @cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1:2 = call @body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
return %0#1 : tensor<1xf32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
%cst = constant dense<0> : tensor<i32> loc("Const")
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %1, %2 : tensor<*xi32>, tensor<*xf32>
%cst = constant dense<1> : tensor<i32> loc("Const")
%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1 : tensor<*xi32>, tensor<*xf32>
}

View File

@ -1,5 +1,6 @@
// RUN: tf-opt --tfl-legalize-tf-while %s -o - | FileCheck %s --dump-input-on-failure
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline --mlir-disable-inline-simplify | FileCheck %s --dump-input-on-failure --check-prefix=INLINE
// RUN: tf-opt --tfl-legalize-tf-while %s -o - --tfl-legalize-tf-while --inline | FileCheck %s --dump-input-on-failure --check-prefix=CANON
func @while_main(%arg0: tensor<?x256x256xf32>) -> (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>) attributes {tf.entry_function = {inputs = "input", outputs = "Identity,Identity_1,Identity_2"}} {
%cst = constant dense<1.000000e+00> : tensor<256x256xf32>
@ -51,3 +52,25 @@ func @while_cond_10_frozen0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: t
// INLINE: yield
// INLINE: while_body
// INLINE: while_cond
// CANON-LABEL: func @while_main
// CANON-SAME: ([[VAL_0:%.*]]: tensor<?x256x256xf32>)
// CANON-SAME: (tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>)
// CANON: [[VAL_1:%.*]] = constant dense<1.000000e+00> : tensor<256x256xf32>
// CANON: [[VAL_2:%.*]] = constant dense<0> : tensor<i32>
// CANON: [[VAL_3:%.*]] = constant dense<10> : tensor<i32>
// CANON: [[VAL_4:%.*]] = constant dense<1> : tensor<i32>
// CANON: [[VAL_5:%.*]] = "tf.Const"() {value = dense<2.560000e+02> : tensor<256x256xf32>} : () -> tensor<?x?xf32>
// CANON: [[VAL_6:%.*]]:3 = "tfl.while"([[VAL_2]], [[VAL_2]], [[VAL_0]]) ( {
// CANON: ^bb0([[VAL_7:%.*]]: tensor<*xi32>, [[VAL_8:%.*]]: tensor<*xi32>, [[VAL_9:%.*]]: tensor<*xf32>):
// CANON: [[VAL_10:%.*]] = "tf.Less"([[VAL_8]], [[VAL_3]])
// CANON: "tfl.yield"([[VAL_10]]) : (tensor<*xi1>) -> ()
// CANON: }, {
// CANON: ^bb0([[VAL_11:%.*]]: tensor<*xi32>, [[VAL_12:%.*]]: tensor<*xi32>, [[VAL_13:%.*]]: tensor<*xf32>):
// CANON: [[VAL_14:%.*]] = "tf.AddV2"([[VAL_12]], [[VAL_4]])
// CANON: [[VAL_15:%.*]] = "tf.AddV2"([[VAL_13]], [[VAL_5]])
// CANON: [[VAL_16:%.*]] = "tf.AddV2"([[VAL_11]], [[VAL_4]])
// CANON: "tfl.yield"([[VAL_16]], [[VAL_14]], [[VAL_15]]) : (tensor<*xi32>, tensor<*xi32>, tensor<*xf32>) -> ()
// CANON: }) {is_stateless = true} : (tensor<i32>, tensor<i32>, tensor<?x256x256xf32>) -> (tensor<i32>, tensor<i32>, tensor<?x256x256xf32>)
// CANON: return [[VAL_17:%.*]]#1, [[VAL_1]], [[VAL_17]]#2 : tensor<i32>, tensor<256x256xf32>, tensor<?x256x256xf32>
// CANON: }

View File

@ -123,6 +123,17 @@ func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
// CHECK-LABEL: softplus
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
}
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
return %0 : tensor<8x8x8x8xf32>
@ -739,6 +750,15 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
}
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK-LABEL: func @matrix_set_diag(
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return [[VAL_0]]
}
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
@ -1364,3 +1384,99 @@ func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
// CHECK: return
}
func @random_uniform() -> tensor<2x5xf32> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform
// CHECK: %[[CST:.*]] = constant dense
// CHECK: return %[[CST:.*]] : tensor<2x5xf32>
}
func @random_uniform_no_fold(%arg0: tensor<2xi32>) -> tensor<2x5xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 0, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf32>
return %1 : tensor<2x5xf32>
// CHECK-LABEL: random_uniform_no_fold
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold2(%arg0: tensor<2xi32>) -> tensor<*xf32> {
%1 = "tf.RandomUniform"(%arg0) { seed = 1, seed2 = 2} : (tensor<2xi32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-LABEL: random_uniform_no_fold2
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @random_uniform_no_fold3() -> tensor<2x5xf64> {
%0 = "tf.Const"() { value = dense<[2, 5]> : tensor<2xi32> } : () -> tensor<2xi32>
%1 = "tf.RandomUniform"(%0) { seed = 1, seed2 = 0} : (tensor<2xi32>) -> tensor<2x5xf64>
return %1 : tensor<2x5xf64>
// CHECK-LABEL: random_uniform_no_fold3
// CHECK: %[[RANDOM:.*]] = "tf.RandomUniform"
}
func @LstmWithoutProjection(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x16xf32>) {
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x28xf32>} : () -> tensor<16x28xf32>
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%6:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %3, %3, %3, %3, %3, %3, %3, %5, %5, %4, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19], device = ""} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1x16xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x16xf32>)
return %6#2 : tensor<28x1x16xf32>
}
// CHECK: func @LstmWithoutProjection([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x16xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<16x28xf32>
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: [[VAL_4:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK: [[VAL_5:%.*]] = constant unit
// CHECK: [[VAL_6:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_2]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_3]], [[VAL_5]], [[VAL_5]], [[VAL_4]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_5]], [[VAL_5]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x28xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, none, none, tensor<1x16xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x16xf32>
// CHECK: return [[VAL_6]] : tensor<28x1x16xf32>
// CHECK: }
func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x16xf32>} : () -> tensor<16x16xf32>
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16x8xf32>} : () -> tensor<16x8xf32>
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<16xf32>} : () -> tensor<16xf32>
%4 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x16xf32>} : () -> tensor<1x16xf32>
%5 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<8x16xf32>} : () -> tensor<8x16xf32>
%6 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x8xf32>} : () -> tensor<1x8xf32>
%7 = "tf.Const"() {device = "", dtype = f32, value = dense<-1.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32>
%8:3 = "tf.UnidirectionalSequenceLstm"(%arg, %1, %1, %1, %1, %2, %2, %2, %2, %7, %7, %7, %3, %3, %3, %3, %5, %7, %6, %4) {_tflite_input_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 18, 19], device = ""} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, tensor<1xf32>, tensor<1x8xf32>, tensor<1x16xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<28x1x8xf32>)
return %8#2 : tensor<28x1x8xf32>
}
// CHECK-LABEL: func @LstmWithProjection(
// CHECK-SAME: [[VAL_7:%.*]]: tensor<28x1x16xf32>) -> tensor<28x1x8xf32> {
// CHECK: [[VAL_8:%.*]] = constant dense<0.000000e+00> : tensor<16x16xf32>
// CHECK: [[VAL_9:%.*]] = constant dense<0.000000e+00> : tensor<16x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: [[VAL_11:%.*]] = constant dense<0.000000e+00> : tensor<1x16xf32>
// CHECK: [[VAL_12:%.*]] = constant dense<0.000000e+00> : tensor<8x16xf32>
// CHECK: [[VAL_13:%.*]] = constant dense<0.000000e+00> : tensor<1x8xf32>
// CHECK: [[VAL_14:%.*]] = constant unit
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
// CHECK: }
func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28x28xf32>} : () -> tensor<28x28xf32>
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28xf32>} : () -> tensor<28xf32>
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x28xf32>} : () -> tensor<1x28xf32>
%4:2 = "tf.UnidirectionalSequenceRnn"(%arg, %1, %1, %2, %3) {_tflite_input_indices = [0, 1, 2, 3, 4], device = ""} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> (tensor<*xf32>, tensor<28x1x28xf32>)
return %4#1 : tensor<28x1x28xf32>
}
// CHECK: func @UnidirectionalRnn([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x28xf32> {
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<28x28xf32>
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<28xf32>
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<1x28xf32>
// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }

View File

@ -1,17 +1,15 @@
// Test to verify translation & export work as intended with runtime.
// RUN: not mlir-tflite-runner --dump-interpreter-state %s 2>&1 | FileCheck %s --check-prefix ERROR --dump-input-on-failure
// RUN: tf-opt --mlir-print-debuginfo --canonicalize --tfl-while-loop-outline %s | mlir-tflite-runner --dump-interpreter-state 2>&1 | FileCheck %s --dump-input-on-failure
// ERROR: number of operands and results don't match
// Verify value computed:
// ----------------------
// CHECK: result: Tensor<type: FLOAT32, shape: 1, values: 96>
// CHECK: pconst: Tensor<type: INT32, shape: , values: 1>
// Verify tensors in interpreter state:
// ------------------------------------
// CHECK: Tensor 0 dec kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK: Tensor 0 pconst kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
@ -24,12 +22,12 @@
// ------------------------------------
// CHECK: Operator Builtin Code {{[0-9]*}} WHILE
func @main() -> tensor<1xf32>
attributes {tf.entry_function = {outputs = "result"}} {
func @main() -> (tensor<1xf32>, tensor<i32>)
attributes {tf.entry_function = {outputs = "result,pconst"}} {
%cst = constant dense<1> : tensor<i32> loc("dec")
%arg0 = constant dense<5> : tensor<i32> loc("N")
%arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
%0:2 = "tfl.while"(%arg0, %arg1, %cst) ( {
%0:3 = "tfl.while"(%arg0, %arg1, %cst) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
%cst_0 = constant dense<0> : tensor<i32>
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
@ -40,7 +38,7 @@ func @main() -> tensor<1xf32>
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
"tfl.yield"(%1, %2, %arg4) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> ()
}) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>)
return %0#1 : tensor<1xf32>
}) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>, tensor<i32>)
return %0#1, %0#2 : tensor<1xf32>, tensor<i32>
}

View File

@ -34,14 +34,14 @@
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tf.While",
// CHECK-NEXT: name: "tfl.while",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tf.While:1",
// CHECK-NEXT: name: "tfl.while:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
@ -193,22 +193,27 @@
// CHECK-NEXT: }
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// While %arg0 is greater than zero, element-wise add %arg1 with itself.
%0:2 = "tf.While"(%arg0, %arg1) {
cond = @cond, body = @body, is_stateless = false
} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1 = call @cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1:2 = call @body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
return %0#1 : tensor<1xf32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
%cst = constant dense<0> : tensor<i32> loc("Const")
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %1, %2 : tensor<*xi32>, tensor<*xf32>
%cst = constant dense<1> : tensor<i32> loc("Const")
%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1 : tensor<*xi32>, tensor<*xf32>
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s --dump-input-on-failure
// Unary math ops
// -----
@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testLstmQuantizedType
func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: %[[RES0:.*]] = constant unit
// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
// CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: return %[[RES1]]
}
// -----
// CHECK-LABEL: testLstm
@ -878,6 +893,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// -----
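// Verify that tfl.pack accepts an unranked operand when the result type is ranked.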
func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
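// Verify that the pack axis may equal the rank of the inputs, creating a new trailing dimension.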
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
@ -1632,6 +1655,7 @@ func @testSplitOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) ->
// -----
// CHECK-LABEL:testSplitOpWithValidTensorType
func @testSplitOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
%split_dim_0 = constant dense<0> : tensor<i32>
%0, %1 = "tfl.split"(%split_dim_0, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)
@ -1639,6 +1663,9 @@ func @testSplitOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf3
%2, %3 = "tfl.split"(%split_dim_1, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
%split_dim_2 = constant dense<1> : tensor<1xi32>
%4, %5 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
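  // Split results with dynamic or unranked dimensions are also accepted.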
%6:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x?xf32>)
%7:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<?x2xf32>, tensor<16x?xf32>)
%8:2 = "tfl.split"(%split_dim_2, %arg0) {num_splits = 2 : i32} : (tensor<1xi32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<*xf32>)
return %0, %1, %2, %3, %4 : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>, tensor<16x2xf32>
}
@ -1984,3 +2011,32 @@ func @testDensify(%arg0: tensor<? x f32>) -> tensor<? x f32> {
%0 = "tfl.densify"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
%cst = constant dense<0> : tensor<i32> loc("Const")
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
%cst = constant dense<1> : tensor<i32> loc("Const1")
%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1 : tensor<*xi32>, tensor<*xf32>
}
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
// expected-error @+1 {{number of operands does not match number of results}}
%0:1 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1 = call @WhileOp_cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1:2 = call @WhileOp_body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
return %0#0 : tensor<i32>
}



@ -717,6 +717,31 @@ func @expandDimsToReshape(%arg0: tensor<6x6x256xf32>) -> tensor<6x6x256x1xf32> {
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: convertTrivialTransposeToReshape
func @convertTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> tensor<1x6x6x256xf32> {
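  // The permutation only moves the size-1 dimension, so the element order is
  // preserved and the transpose can be rewritten as a reshape.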
%cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
return %0 : tensor<1x6x6x256xf32>
// CHECK: [[CONST:.*]] = constant dense<[1, 6, 6, 256]> : tensor<4xi32>
// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: doNotConvertNonTrivialTransposeToReshape
func @doNotConvertNonTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> tensor<1x6x6x256xf32> {
  // Note: dimensions 0 and 1 are swapped, so the transpose is not trivial
  // (the elements are not in the same order).
%cst = constant dense<[3, 1, 0, 2]> : tensor<4xi32>
%0 = "tfl.transpose"(%arg0, %cst) : (tensor<6x6x256x1xf32>, tensor<4xi32>) -> tensor<1x6x6x256xf32>
return %0 : tensor<1x6x6x256xf32>
// CHECK: [[CONST:.*]] = constant dense<[3, 1, 0, 2]> : tensor<4xi32>
// CHECK: %[[RESULT:.*]] = "tfl.transpose"(%arg0, %[[CONST:.*]])
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu1
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense<-1.0> : tensor<f32>


@ -96,3 +96,40 @@ func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify that an unused tf.If whose branch functions have no side effects is removed.
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%cst_0 = constant dense<1.000000e+00> : tensor<f32>
%cst_1 = constant dense<0.000000e+00> : tensor<8xf32>
%cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32>
%0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor<f32>) -> tensor<3x15x14x3xf32>
%1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1>
%2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor<i1>
%3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1,
else_branch = @_functionalize_if_else_branch_00, is_stateless = false,
then_branch = @_functionalize_if_then_branch_00} :
(tensor<i1>, tensor<i1>, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<i1>
%4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
%cst = constant dense<true> : tensor<i1>
return %cst : tensor<i1>
}
// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func else_branch
// CHECK-NOT: func then_branch


@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
// -----
module {
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
@ -165,7 +165,7 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@ -181,7 +181,127 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
// CHECK: }
}
// -----
module {
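// With tf.time_major = false, tf.Transpose ops are expected around the fused
// tfl.lstm to convert the batch-major input and output to and from
// time-major layout (see the CHECK lines below).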
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
// CHECK: }
}
// -----
module {
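// With tf.go_backwards = true, a tf.ReverseV2 along the time dimension is
// expected before the fused tfl.lstm.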
func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32>
// CHECK: }
}
// -----
module {
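// With tf.time_major = false and tf.go_backwards = true, the CHECK lines expect
// a transpose to time-major layout, a reversal along the time dimension, and a
// transpose of the LSTM output back to batch-major layout.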
func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_23:%.*]] = constant unit
// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: return [[VAL_26]] : tensor<8x8x10xf32>
// CHECK: }
}


@ -622,3 +622,16 @@ func @QuantizeSharedBiases2(
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}
// CHECK-LABEL: ReturnQuantizedResult
func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) {
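  // Both returned values should end up using the dequantized result of the
  // quantized op (see the CHECK lines below).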
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
%2 = "tfl.dequantize"(%1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> (tensor<1x112x112x32xf32>)
return %0, %2 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>
// CHECK: %[[dw:.*]] = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2)
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[dw]])
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
// CHECK: return %[[dq]], %[[dq]]
}


@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s --dump-input-on-failure
func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) :
@ -117,6 +117,37 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
func @batchNormWithGlobalNormalization(
%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = false} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
return %0 : tensor<1x10x10x3xf32>
// CHECK-LABEL: batchNormWithGlobalNormalization
// CHECK: %[[EPSILON:.*]] = constant dense<1.000000e-03>
// CHECK: %[[VARIANCE:.*]] = "tf.Add"(%[[ARG_V:.*]], %[[EPSILON]])
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[VARIANCE]])
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG_T:.*]], %[[RSQRT]])
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG_M:.*]], %[[RSQRT]])
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG_BETA:.*]], %[[MUL2]])
// CHECK: %[[RESULT:.*]] = "tf.Add"(%[[MUL1]], %[[SUB]])
// CHECK: return %[[RESULT]]
}
func @batchNormWithGlobalNormalizationWithScaleAfterNormalization(
%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = true} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
return %0 : tensor<1x10x10x3xf32>
// CHECK-LABEL: batchNormWithGlobalNormalizationWithScaleAfterNormalization
// CHECK: %[[EPSILON:.*]] = constant dense<1.000000e-03>
// CHECK: %[[VARIANCE:.*]] = "tf.Add"(%[[ARG_V:.*]], %[[EPSILON]])
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[VARIANCE]])
// CHECK: %[[MUL0:.*]] = "tf.Mul"(%[[RSQRT]], %[[ARG_GAMMA:.*]])
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG_T:.*]], %[[MUL0]])
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG_M:.*]], %[[MUL0]])
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG_BETA:.*]], %[[MUL2]])
// CHECK: %[[RESULT:.*]] = "tf.Add"(%[[MUL1]], %[[SUB]])
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
%arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
@ -422,6 +453,30 @@ func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
// CHECK: return %arg0 : tensor<3xf32>
}
// CHECK-LABEL: @StridedSliceEllipsisMaskBefore
func @StridedSliceEllipsisMaskBefore(%arg0: tensor<21x15x7xf32>) -> tensor<21x15x2xf32> {
%cst = constant dense<0> : tensor<2xi32>
%cst_0 = constant dense<1> : tensor<2xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<21x15x2xf32>
return %0 : tensor<21x15x2xf32>
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 3 : i64, ellipsis_mask = 0 : i64, end_mask = 3 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<21x15x2xf32>
}
// CHECK-LABEL: @StridedSliceEllipsisMaskAfter
func @StridedSliceEllipsisMaskAfter(%arg0: tensor<21x15x7xf32>) -> tensor<5x15x7xf32> {
%cst = constant dense<0> : tensor<2xi32>
%cst_0 = constant dense<1> : tensor<2xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst, %cst_0) {begin_mask = 0 : i64, ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x15x7xf32>
return %0 : tensor<5x15x7xf32>
// CHECK: %[[CST:.*]] = constant dense<0> : tensor<3xi32>
// CHECK: %[[CST_0:.*]] = constant dense<1> : tensor<3xi32>
// CHECK: %[[STRIDED_SLICE:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST]], %[[CST_0]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<21x15x7xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<5x15x7xf32>
}
// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
%cst = constant dense<0> : tensor<4xi32>
@ -456,3 +511,34 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64
%1 = "tf.StridedSlice"(%0, %cst, %cst, %cst_0) {Index = i32, T = f32, _output_shapes = ["tfshape$dim { size: 1 } dim { size: 4 } dim { size: 64 } dim { size: 64 }"], begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 1 : i64, shrink_axis_mask = 0 : i64} : (tensor<4x64x64xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x4x64x64xf32>
return %1 : tensor<1x4x64x64xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV3Conversion
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}


@ -2,39 +2,44 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<f32> {
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<f32>
return %2 : tensor<f32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<f32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<0> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeDenseFloatConst
func @QuantizeDenseFloatConst() -> tensor<2x2xf32> {
func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<{{\[\[}}0, -1], {{\[}}-1, -1]]> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: return %[[cst]]
}
// CHECK-LABEL: QuantizeSplatFloatConst
func @QuantizeSplatFloatConst() -> tensor<2x2xf32> {
func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%0 = constant dense<3.0> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: return %[[cst]]
}
// CHECK-LABEL: NotQuantizeFloatConst
func @NotQuantizeFloatConst() -> tensor<2x2xf32> {
%0 = constant dense<-0.1> : tensor<2x2xf32>
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
%2 = "tfl.dequantize"(%1) : (tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>}
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[cst]])
// CHECK: return %[[dq]] : tensor<2x2xf32>
// CHECK: %[[cst:.*]] = constant dense<-1.000000e-01> : tensor<2x2xf32>
// CHECK: return %[[cst]] : tensor<2x2xf32>
}
// CHECK-LABEL: DequantizeAndQuantize
@ -71,7 +76,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>
// DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
// DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]])
// DEBUG: %[[q_conv:.*]] = "tfl.conv_2d"
// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 1.000000e-01 : f32}
// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 5.000000e+00 : f32}
// DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
}
@ -236,8 +241,8 @@ func @QuantizeSplit(%arg: tensor<4x!quant.uniform<u8:f32, 1.0>>, %cst: tensor<i3
// DEBUG: %[[f_split:.*]]:2 = "tfl.split"
// DEBUG: %[[q_split:.*]]:2 = "tfl.split"
// DEBUG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {tolerance = 1.000000e-01 : f32}
// DEBUG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {tolerance = 1.000000e-01 : f32}
// DEBUG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {tolerance = 5.000000e+00 : f32}
// DEBUG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {tolerance = 5.000000e+00 : f32}
}
// CHECK-LABEL: QuantizeSplitUnusedResults
