Merge branch 'master' into 12829-extract_glimpse
This commit is contained in:
commit
158d128323
16
.bazelrc
16
.bazelrc
@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
||||
build:mkl_threadpool -c opt
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
|
29
.github/bot_config.yml
vendored
Normal file
29
.github/bot_config.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- amahendrakar
|
||||
- ravikyram
|
||||
- Saduf2019
|
||||
# A list of assignees for compiler-related issues
|
||||
compiler_assignees:
|
||||
- joker-eph
|
22
README.md
22
README.md
@ -103,17 +103,17 @@ open-source software development:
|
||||
|
||||
### Official Builds
|
||||
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
Build Type | Status | Artifacts
|
||||
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
|
||||
**Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Linux XLA** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
|
||||
**macOS** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
|
||||
**Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
|
25
RELEASE.md
25
RELEASE.md
@ -3,6 +3,31 @@
|
||||
## Breaking Changes
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case where `centered=False` and `normalized=False`. This is a breaking change as the output is different from (incorrect) previous versions. Note this breaking change only impacts `tf.image.extract_glimpse` and `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of `tf.compat.v1.image.extract_glimpse` does not change. The behavior of the existing C++ kernel `ExtractGlimpse` does not change either, so saved models will not be impacted.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
|
||||
|
||||
# Release 2.0.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 1.15.3
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
|
@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
|
||||
|
||||
It is possible to write models that are secure in a sense that they can safely
|
||||
process untrusted inputs assuming there are no bugs. There are two main reasons
|
||||
to not rely on this: first, it is easy to write models which must not be exposed
|
||||
to not rely on this: First, it is easy to write models which must not be exposed
|
||||
to untrusted inputs, and second, there are bugs in any software system of
|
||||
sufficient complexity. Letting users control inputs could allow them to trigger
|
||||
bugs either in TensorFlow or in dependent libraries.
|
||||
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
|
||||
vulnerability in TensorFlow (although it would be a vulnerability of this
|
||||
hypothetical system).
|
||||
|
||||
As a general rule, it is incorrect behavior for Tensorflow to access memory it
|
||||
As a general rule, it is incorrect behavior for TensorFlow to access memory it
|
||||
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
|
||||
such behaviors constitute a vulnerability.
|
||||
|
||||
|
25
configure.py
25
configure.py
@ -144,7 +144,7 @@ def write_to_bazelrc(line):
|
||||
|
||||
|
||||
def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
@ -205,7 +205,7 @@ def setup_python(environ_cp):
|
||||
# Get PYTHON_BIN_PATH, default is the current running python.
|
||||
default_python_bin_path = sys.executable
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
'{}]: ').format(default_python_bin_path)
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
@ -215,9 +215,10 @@ def setup_python(environ_cp):
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
elif not os.path.exists(python_bin_path):
|
||||
print('Invalid python path: %s cannot be found.' % python_bin_path)
|
||||
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
|
||||
else:
|
||||
print('%s is not executable. Is it the python binary?' % python_bin_path)
|
||||
print('{} is not executable. Is it the python binary?'.format(
|
||||
python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = ''
|
||||
|
||||
# Convert python path to Windows style before checking lib and version
|
||||
@ -236,7 +237,7 @@ def setup_python(environ_cp):
|
||||
default_python_lib_path = python_lib_paths[0]
|
||||
python_lib_path = get_input(
|
||||
'Please input the desired Python library path to use. '
|
||||
'Default is [%s]\n' % python_lib_paths[0])
|
||||
'Default is [{}]\n'.format(python_lib_paths[0]))
|
||||
if not python_lib_path:
|
||||
python_lib_path = default_python_lib_path
|
||||
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
|
||||
@ -252,7 +253,7 @@ def setup_python(environ_cp):
|
||||
# Set-up env variables used by python_configure.bzl
|
||||
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
|
||||
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
|
||||
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
|
||||
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
||||
|
||||
# If chosen python_lib_path is from a path specified in the PYTHONPATH
|
||||
@ -266,7 +267,7 @@ def setup_python(environ_cp):
|
||||
with open(
|
||||
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
|
||||
'w') as f:
|
||||
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
||||
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
|
||||
|
||||
|
||||
def reset_tf_configure_bazelrc():
|
||||
@ -320,11 +321,12 @@ def get_var(environ_cp,
|
||||
Raise the error to avoid infinitely looping.
|
||||
"""
|
||||
if not question:
|
||||
question = 'Do you wish to build TensorFlow with %s support?' % query_item
|
||||
question = 'Do you wish to build TensorFlow with {} support?'.format(
|
||||
query_item)
|
||||
if not yes_reply:
|
||||
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
|
||||
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
|
||||
if not no_reply:
|
||||
no_reply = 'No %s' % yes_reply
|
||||
no_reply = 'No {}'.format(yes_reply)
|
||||
|
||||
yes_reply += '\n'
|
||||
no_reply += '\n'
|
||||
@ -368,7 +370,7 @@ def get_var(environ_cp,
|
||||
print(no_reply)
|
||||
var = False
|
||||
else:
|
||||
print('Invalid selection: %s' % user_input_origin)
|
||||
print('Invalid selection: {}'.format(user_input_origin))
|
||||
return var
|
||||
|
||||
|
||||
@ -1385,7 +1387,6 @@ def main():
|
||||
# Windows.
|
||||
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_MPI'] = '0'
|
||||
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -85,7 +85,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:chromiumos": [
|
||||
":tf_attrtype",
|
||||
@ -182,7 +182,7 @@ tf_cuda_library(
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status",
|
||||
@ -219,7 +219,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
@ -232,12 +232,13 @@ cc_library(
|
||||
srcs = ["tf_status.cc"],
|
||||
hdrs = ["tf_status.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tf_status_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
@ -259,10 +260,15 @@ cc_library(
|
||||
name = "tensor_interface",
|
||||
hdrs = ["tensor_interface.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -272,7 +278,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -286,16 +292,17 @@ cc_library(
|
||||
srcs = ["tf_tensor.cc"],
|
||||
hdrs = ["tf_tensor.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
":tf_tensor_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -311,14 +318,15 @@ tf_cuda_library(
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
deps = [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
@ -386,8 +394,14 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -426,7 +440,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
@ -457,7 +471,7 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
@ -484,7 +498,7 @@ tf_cuda_library(
|
||||
":tf_status_helper",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -35,7 +35,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
@ -319,6 +319,7 @@ tf_cuda_cc_test(
|
||||
tags = [
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"notap", # TODO(b/156981931): flaky
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -357,10 +358,13 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -412,7 +416,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
@ -448,6 +452,8 @@ tf_cuda_library(
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
|
@ -899,9 +899,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
status->status = tensorflow::unwrap(ctx)->AsyncWait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -924,7 +922,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
// Indicates that the caller will not be using `h` any more.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
|
@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
// Rewrites the address of task `task_index` in the first job (index 0) of
// `server_def`'s cluster to "localhost:<port>", where <port> is a freshly
// picked unused port. `server_def` is modified in place.
// NOTE(review): the entry is reached via `at(task_index)`, so this presumably
// requires that `task_index` already exists in the job's task map -- confirm.
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->at(task_index) =
|
||||
tensorflow::strings::StrCat("localhost:", port);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Reads the value of variable `var` and saves it into `out_value`.
//
// Builds a "ReadVariableOp" on `ctx` with dtype TF_FLOAT, feeds it `var` as
// its only input and executes it expecting exactly one return value, which is
// written to `*out_value`. The temporary op and status are deleted before
// returning.
// NOTE(review): each ASSERT_EQ below returns early on failure, which would
// skip the TFE_DeleteOp / TF_DeleteStatus cleanup at the end -- presumably
// acceptable for test-only code; confirm.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
// Exactly one output (the variable's value) is requested.
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
// Checks resource access after the remote worker of a two-task cluster has
// been replaced via TFE_ContextUpdateServerDef.
//
// Steps: start a worker for task 1, create one variable on the task-0 CPU
// device and one on the task-1 CPU device, read the task-1 variable, replace
// task 1 with a new worker on a fresh port, update the context's server def,
// then run "DestroyResourceOp" on both variables. The op on the task-0
// variable must succeed; the op on the task-1 variable must report a non-OK
// status.
//
// `async` is forwarded to TFE_ContextOptionsSetAsync for the context under
// test.
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
// The worker server below is created with the task index switched to 1.
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// The context is configured with the task-0 serialization captured above.
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
// Read the task-1 variable back and check it holds the value it was
// created with.
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
// release() gives up ownership of the old server without destroying it.
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is the local device, which is the same before and
|
||||
// after the cluster update. Removing a resource on a valid device should
// succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
// The op is executed with no requested outputs.
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is the remote device, which was replaced during
|
||||
// the cluster update. Removing a resource on an invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
// A non-OK status is the expected outcome here.
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
// Runs the resource-access-after-cluster-update scenario with the helper's
// flag set to false (the sibling *Async test passes true).
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
  constexpr bool kAsync = false;
  TestRemoteExecuteUpdateServerDefResourceAccess(kAsync);
}
|
||||
|
||||
// Runs the resource-access-after-cluster-update scenario with the helper's
// flag set to true.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
  constexpr bool kAsync = true;
  TestRemoteExecuteUpdateServerDefResourceAccess(kAsync);
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-existent worker
|
||||
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
|
@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
||||
// Forwards `enable` to EagerContext::SetAllowSoftPlacement on the eager
// context that backs `ctx`. `status` is accepted for API symmetry but is never
// written by this function.
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                       TF_Status* status) {
  auto* unwrapped_ctx = tensorflow::unwrap(ctx);
  tensorflow::EagerContext* eager_context =
      tensorflow::ContextFromInterface(unwrapped_ctx);
  eager_context->SetAllowSoftPlacement(enable);
}
|
||||
|
||||
// Forwards `enable` to EagerContext::SetLogDevicePlacement on the eager
// context that backs `ctx`. `status` is accepted for API symmetry but is never
// written by this function.
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
                                      TF_Status* status) {
  auto* unwrapped_ctx = tensorflow::unwrap(ctx);
  tensorflow::EagerContext* eager_context =
      tensorflow::ContextFromInterface(unwrapped_ctx);
  eager_context->SetLogDevicePlacement(enable);
}
|
||||
|
@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure soft device placement policy for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure device placement policy logging for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -19,11 +19,16 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
@ -434,7 +439,26 @@ string AddVariablesFunction() {
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
// Executes "VarIsInitializedOp" on `var_handle` in `ctx`, resolves the single
// result tensor and expects it to hold `true`.
//
// Callers in this file use it as a synchronization point: resolving the result
// forces the op (and therefore the variable it reads) to have been processed.
// Failures are reported through gtest expectations / CHECKs; nothing is
// returned.
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(op, var_handle, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_TensorHandle* is_initialized[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(op, &is_initialized[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  CHECK_EQ(1, num_retvals);

  TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  bool initialized = false;
  // The copy length must be bounded by the destination, not by whatever size
  // the tensor reports; otherwise an unexpected tensor would smash the stack.
  CHECK_EQ(sizeof(initialized), TF_TensorByteSize(t));
  memcpy(&initialized, TF_TensorData(t), sizeof(initialized));
  EXPECT_EQ(initialized, true);

  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(is_initialized[0]);
  TFE_DeleteOp(op);
  // Release through the C API that allocated it rather than a raw `delete`.
  TF_DeleteStatus(status);
}
|
||||
|
||||
void TestFunctionWithPackedInput(const bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -474,6 +498,12 @@ TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Add a sync point in order to make sure that variables have been initialized
|
||||
// before the function execution starts.
|
||||
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
||||
VarIsInitialized(ctx, h1);
|
||||
VarIsInitialized(ctx, h2);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
@ -502,6 +532,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(func, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -537,6 +571,189 @@ TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
// Packed-input function test without pinning the function to a remote task.
TEST(CAPI, TestLocalFunctionWithPackedInput) {
  constexpr bool kRemote = false;
  TestFunctionWithPackedInput(kRemote);
}
|
||||
|
||||
// Packed-input function test with `remote` set, which makes the helper place
// the function op on task1 via TFE_OpSetDevice.
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
  constexpr bool kRemote = true;
  TestFunctionWithPackedInput(kRemote);
}
|
||||
|
||||
// Builds the serialized FunctionDef for 'VariableAddFunction'.
//
// The function takes one DT_RESOURCE argument (`var0`), reads it as a float,
// adds the value to itself on task:1, passes the sum through an Identity on
// task:0 and returns it as `var0_value`. The text proto is assembled from
// adjacent string literals; the source layout below mirrors the proto nesting
// but the literal contents are what get parsed.
string VariableAddFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      // Signature: (var0: resource) -> (var0_value: float).
      " signature {"
        " name: 'VariableAddFunction'"
        " input_arg {"
          " name: 'var0'"
          " type: DT_RESOURCE"
        " }"
        " output_arg {"
          " name: 'var0_value'"
          " type: DT_FLOAT"
        " }"
      " }"
      // read0 = ReadVariableOp(var0); no device is requested for this node.
      " node_def {"
        " name: 'read0'"
        " op: 'ReadVariableOp'"
        " input: 'var0'"
        " attr {"
          " key: 'dtype'"
          " value {"
            " type: DT_FLOAT"
          " }"
        " }"
      " }"
      // add = read0 + read0, requested on task:1.
      " node_def {"
        " name: 'add'"
        " op: 'Add'"
        " input: 'read0:value:0'"
        " input: 'read0:value:0'"
        " device: '/job:localhost/task:1/device:CPU:0'"
        " attr {"
          " key: 'T'"
          " value {"
            " type: DT_FLOAT"
          " }"
        " }"
      " }"
      // identity = Identity(add), requested on task:0.
      " node_def {"
        " name: 'identity'"
        " op: 'Identity'"
        " input: 'add:z:0'"
        " device: '/job:localhost/task:0/device:CPU:0'"
        " attr {"
          " key: 'T'"
          " value {"
            " type: DT_FLOAT"
          " }"
        " }"
      " }"
      // Map the declared output to the identity node's output.
      " ret {"
        " key: 'var0_value'"
        " value: 'identity:output:0'"
      " }",
      &def));
  return def.SerializeAsString();
}
|
||||
|
||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||
public:
|
||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||
: error_node_(error_node), error_device_(error_device) {}
|
||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override {
|
||||
// Inject failure to function instantiation if finding a node that contains
|
||||
// the given node name (error_node_) and requested device (error_device_).
|
||||
for (const auto node : graph->get()->nodes()) {
|
||||
if (node->name().find(error_node_) != string::npos &&
|
||||
node->requested_device() == error_device_) {
|
||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const string error_node_;
|
||||
const string error_device_;
|
||||
};
|
||||
|
||||
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
if (inject_error) {
|
||||
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||
// having a requested device `dev2_name`. During execution:
|
||||
// * task:0 processes the main function `VariableAddFunction` and places
|
||||
// the read0 op on task:2
|
||||
// * task:0 partitions the main function with a subgraph containing read0
|
||||
// sent to task:2
|
||||
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||
tensorflow::function_optimization_registration::
|
||||
FunctionOptimizationPassRegistration register_test_pass(
|
||||
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle, nullptr);
|
||||
|
||||
const string function_def = VariableAddFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, var_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
|
||||
if (inject_error) {
|
||||
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
ASSERT_EQ(sum, 4.0);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(var_handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
// Distributed function call without any injected failure.
TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(/*inject_error=*/false);
}
|
||||
|
||||
// Distributed function call with the graph-pass failure injected.
TEST(CAPI, DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(/*inject_error=*/true);
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
|
@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
|
@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -26,6 +28,51 @@ using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
|
||||
|
||||
// Returns the process-wide name -> factory map. The map is heap-allocated on
// first use and never deleted.
static FactoriesMap& GetFactories() {
  static FactoriesMap* const registry = new FactoriesMap();
  return *registry;
}
|
||||
|
||||
static const char* default_factory = "<unset>";
|
||||
|
||||
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||
assert((!GetFactories().count(name)) ||
|
||||
(GetFactories()[name] == factory) &&
|
||||
"Duplicate tracing factory registration");
|
||||
GetFactories()[name] = factory;
|
||||
}
|
||||
|
||||
// Selects the factory key used by later tracing-context creation. Only the
// pointer is stored (no copy is made), so `name` has to stay valid for as long
// as it may be read.
void SetDefaultTracingEngine(const char* name) {
  default_factory = name;
}
|
||||
|
||||
// Creates a tracing context for `fn_name` with the factory registered under
// the current default key.
//
// If no factory is registered under that key, sets `s` to TF_INVALID_ARGUMENT
// with a message listing the registered keys, and returns nullptr.
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
                                                       TF_Status* s) {
  FactoriesMap& registry = GetFactories();
  const auto selected = registry.find(default_factory);
  if (selected == registry.end()) {
    string msg = absl::StrCat(
        "No tracing engine factory has been registered with the key '",
        default_factory, "' (available: ");

    // Sort the keys so the error message is deterministic.
    std::set<string> sorted_names;
    for (const auto& entry : registry) {
      sorted_names.insert(entry.first);
    }

    const char* separator = "";
    for (const string& registered_name : sorted_names) {
      msg += separator;
      msg += registered_name;
      separator = ", ";
    }
    msg += ")";

    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
    return nullptr;
  }
  return selected->second(fn_name, s);
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
//
|
||||
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
void TF_SetTracingImplementation(const char* name) {
|
||||
tensorflow::internal::SetDefaultTracingEngine(name);
|
||||
}
|
||||
|
||||
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||
// given tracing context.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
|
||||
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
|
||||
}
|
||||
|
||||
// Finalizes the function traced in `ctx` with `outputs` as its results, then
// deletes `ctx` unconditionally (also when Finalize reported an error in `s`).
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
                                         TF_OutputList* outputs, TF_Status* s) {
  auto* finalized = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
  // The context is consumed by this call regardless of the outcome.
  TF_DeleteExecutionContext(ctx);
  return finalized;
}
|
||||
|
||||
// Adds a parameter of type `dtype` to the function being traced in `func` and
// returns the wrapped tensor that AddParameter produced.
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
                                           TF_DataType dtype, TF_Status* s) {
  auto* parameter = unwrap(func)->AddParameter(dtype, s);
  return wrap(parameter);
}
|
||||
|
||||
// Destroys the execution context behind the opaque handle `c`.
void TF_DeleteExecutionContext(TF_ExecutionContext* c) {
  auto* context = unwrap(c);
  delete context;
}
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
// Returns the `i`th tensor of the list. `i` is not range-checked here.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
  auto* output_list = unwrap(o);
  return wrap(output_list->outputs[i]);
}
|
||||
// Appends `tensor` to the end of the list. `s` is never written by this
// function.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
                           TF_Status* s) {
  auto* output_list = unwrap(o);
  output_list->outputs.push_back(unwrap(tensor));
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
|
@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// Creates a context for tracing the execution of operations into a function.
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||
// This allows the client to swap the implementation of the tracing engine.
|
||||
// Any future call to TF_CreateFunction will use the implementation defined
|
||||
// here.
|
||||
void TF_SetTracingImplementation(const char* name);
|
||||
|
||||
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||
// tracing, a function can be obtained by TF_FinalizeFunction.
|
||||
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
// Add a new parameter to a TensorFlow Function.
|
||||
// TODO(aminim): what about shape?
|
||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||
TF_DataType dtype, TF_Status* s);
|
||||
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation.
|
||||
// It just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TODO(aminim): the description above isn't clear with respect to
|
||||
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||
// the number of outputs to be set by the client.
|
||||
// an operation, or provided to create a function.
|
||||
// When executing an operation in an eager context, the expected number of
|
||||
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
|
||||
// Prepare tracing for the expected number of outputs of an operation.
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
|
||||
// Return the number of outputs in the list.
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
// Return the `i`th output in the list.
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
// Append a tensor at the end of the output list, growing its size by one.
|
||||
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
|
||||
TF_Status*);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||
// context. The provided `ctx` is consumed by this API call and deleted.
|
||||
// The returned TF_AbstractFunction must be deleted by the client.
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||
TF_OutputList*, TF_Status*);
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
|
@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
// Function parameters only exist while tracing: on an eager context this
// always sets `s` to TF_INVALID_ARGUMENT and returns nullptr.
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
  static constexpr char kUnsupported[] =
      "Can't add function parameter on an eager context.";
  TF_SetStatus(s, TF_INVALID_ARGUMENT, kUnsupported);
  return nullptr;
}
|
||||
// An eager context cannot be turned into a function: this always sets `s` to
// TF_INVALID_ARGUMENT and returns nullptr.
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
  static constexpr char kUnsupported[] =
      "Can't use finalize function on an eager context.";
  TF_SetStatus(s, TF_INVALID_ARGUMENT, kUnsupported);
  return nullptr;
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||
// adding them to the graph.
|
||||
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||
// "execution" of operation, i.e. adding them to the function.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
GraphContext()
|
||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
explicit GraphContext(const char* name)
|
||||
: ExecutionContext(kKind),
|
||||
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||
name_(name) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
if (tf_opdesc == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const GraphTensor* inputs, int num_outputs,
|
||||
const GraphTensor* outputs, TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
// Adds a "Placeholder" op of type `dtype` to the graph, records its output 0
// as the next function input and returns a new GraphTensor for it. The
// placeholder is named "_input_<k>", k being the number of inputs so far.
// Returns nullptr if finishing the operation left a non-OK status in `s`.
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
  const auto placeholder_name = absl::StrCat("_input_", inputs_.size());
  TF_OperationDescription* description = TF_NewOperation(
      graph_.get(), "Placeholder", placeholder_name.c_str());
  TF_SetAttrType(description, "dtype", dtype);

  auto* placeholder = TF_FinishOperation(description, s);
  if (!s->status.ok()) return nullptr;

  inputs_.push_back(TF_Output{placeholder, 0});
  return new GraphTensor(inputs_.back(), this);
}
|
||||
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = inputs[i].output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = outputs[i].output;
|
||||
graph_outputs.reserve(outputs->outputs.size());
|
||||
for (AbstractTensor* abstract_output : outputs->outputs) {
|
||||
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
||||
if (!output) {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Returning a non-graph tensor from a function has not "
|
||||
"been implemented yet.");
|
||||
return nullptr;
|
||||
}
|
||||
graph_outputs.push_back(output->output);
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
func->func = TF_GraphToFunction(
|
||||
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||
return func.release();
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
std::vector<TF_Output> inputs_;
|
||||
const char* name_;
|
||||
};
|
||||
|
||||
// Helper that converts the graph currently held in the context into a function.
|
||||
static AbstractFunction* ExecutionContextToFunction(
|
||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const AbstractTensor* inputs, int num_outputs,
|
||||
const AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||
if (!graph_inputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||
if (!graph_outputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
// Factory for the graph-based tracing engine: returns a new GraphContext for
// the function `name`. `s` is not written.
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
  ExecutionContext* tracing_context = new GraphContext(name);
  return tracing_context;
}
|
||||
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||
SetDefaultTracingEngine("graphdef");
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Graph API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::GraphContext());
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||
unwrap(inputs), num_outputs,
|
||||
unwrap(outputs), status));
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
@ -148,6 +149,17 @@ struct ExecutionContext {
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Add a function parameter and return the corresponding tensor.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
|
||||
|
||||
// Finalize this context and make a function out of it. The context is in an
|
||||
// invalid state after this call and must be destroyed.
|
||||
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||
// it'll always error out with an eager context.
|
||||
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
|
||||
|
||||
// Registers a function with this context; after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
@ -156,6 +168,11 @@ struct ExecutionContext {
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||
void SetDefaultTracingEngine(const char* name);
|
||||
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||
FactoryFunction factory);
|
||||
|
||||
// Create utilities to wrap/unwrap: this converts from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
|
@ -29,7 +29,12 @@ using tensorflow::string;
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicEager) {
|
||||
// Value-parameterized fixture for the unified C API tests. The parameter is
// the name handed to TF_SetTracingImplementation before each test body runs.
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
 protected:
  void SetUp() override {
    const char* tracing_impl = GetParam();
    TF_SetTracingImplementation(tracing_impl);
  }
};
|
||||
|
||||
TEST_P(UnifiedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
// Start a new function / execution context.
|
||||
string fn_name = "double";
|
||||
TF_ExecutionContext* graph_ctx =
|
||||
TF_CreateFunction(fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
auto* placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
|
||||
string fn_name = "double";
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
||||
TF_AbstractFunction* func =
|
||||
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractTensor(output_t);
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -174,17 +160,160 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
// Traces a two-parameter function with two "Add" nodes, returns both results
// from the function, then runs it through an eager context and checks both
// outputs.
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status_owner(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* st = status_owner.get();

  // Open a new function; the returned context is used for tracing.
  string function_name = "two_adds";
  TF_ExecutionContext* tracing_ctx =
      TF_CreateFunction(function_name.c_str(), st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);

  // Two float parameters.
  auto* param_a = TF_AddFunctionParameter(tracing_ctx, TF_FLOAT, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  auto* param_b = TF_AddFunctionParameter(tracing_ctx, TF_FLOAT, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);

  // First node: `param_a + param_b`.
  TF_AbstractTensor* sum_ab;
  {
    auto* op = TF_NewAbstractOp(tracing_ctx);
    TF_AbstractOpSetOpType(op, "Add", st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_AbstractOpSetOpName(op, "my_add1", st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_AbstractTensor* operands[2] = {param_a, param_b};
    TF_OutputList* op_outputs = TF_NewOutputList();
    // Tracing the op adds a node to the function being built.
    TF_ExecuteOperation(op, 2, operands, op_outputs, tracing_ctx, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_DeleteAbstractOp(op);
    // Keep the produced tensor; the list itself is no longer needed.
    sum_ab = TF_OutputListGet(op_outputs, 0);
    TF_DeleteOutputList(op_outputs);
  }

  // Second node: `param_b + param_b`.
  TF_AbstractTensor* sum_bb;
  {
    auto* op = TF_NewAbstractOp(tracing_ctx);
    TF_AbstractOpSetOpType(op, "Add", st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_AbstractOpSetOpName(op, "my_add2", st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_AbstractTensor* operands[2] = {param_b, param_b};
    TF_OutputList* op_outputs = TF_NewOutputList();
    // Tracing the op adds a node to the function being built.
    TF_ExecuteOperation(op, 2, operands, op_outputs, tracing_ctx, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_DeleteAbstractOp(op);
    // Keep the produced tensor; the list itself is no longer needed.
    sum_bb = TF_OutputListGet(op_outputs, 0);
    TF_DeleteOutputList(op_outputs);
  }

  // Finalize the function, returning the results of both additions.
  TF_AbstractFunction* traced_fn;
  {
    TF_OutputList* returned = TF_NewOutputList();
    TF_OutputListPushBack(returned, sum_ab, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_OutputListPushBack(returned, sum_bb, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    traced_fn = TF_FinalizeFunction(tracing_ctx, returned, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_DeleteOutputList(returned);
  }

  // So far we have traced this function:
  //
  //   def two_adds(a, b):
  //     my_add1 = a + b
  //     my_add2 = b + b
  //     return my_add1, my_add2
  //
  // Next it is executed with an eager context:
  //
  //   output1, output2 = two_adds(2.0, 3.0)
  //
  // and the results are expected to be 5.0 and 6.0.

  // Eager context used to run the traced function.
  TFE_ContextOptions* ctx_options = TFE_NewContextOptions();
  TF_ExecutionContext* eager_exec_ctx =
      TF_NewEagerExecutionContext(ctx_options, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  TFE_DeleteContextOptions(ctx_options);

  TF_ExecutionContextRegisterFunction(eager_exec_ctx, traced_fn, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);

  // Abstract op that calls the registered function by name.
  TF_AbstractOp* call_op = TF_NewAbstractOp(eager_exec_ctx);
  TF_AbstractOpSetOpType(call_op, function_name.c_str(), st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);

  // Arguments: the scalars 2.0 and 3.0 wrapped as abstract tensors.
  std::vector<TF_AbstractTensor*> call_args;
  {
    TFE_Context* tfe_ctx = TF_ExecutionContextGetTFEContext(eager_exec_ctx);
    TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_ctx, 2.0f);
    call_args.push_back(TF_CreateAbstractTensorFromEagerTensor(scalar, st));
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    scalar = TestScalarTensorHandle(tfe_ctx, 3.0f);
    call_args.push_back(TF_CreateAbstractTensorFromEagerTensor(scalar, st));
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  }

  TF_OutputList* call_outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(call_outputs, 2, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  TF_ExecuteOperation(call_op, call_args.size(), call_args.data(),
                      call_outputs, eager_exec_ctx, st);
  ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  TF_DeleteAbstractOp(call_op);
  for (TF_AbstractTensor* arg : call_args) {
    TF_DeleteAbstractTensor(arg);
  }

  // Resolve both outputs to host tensors and read their float payloads.
  ASSERT_EQ(2, TF_OutputListNumOutputs(call_outputs));
  float values[2];
  for (int i = 0; i < 2; ++i) {
    TF_AbstractTensor* output = TF_OutputListGet(call_outputs, i);
    TFE_TensorHandle* eager_handle = TF_AbstractTensorGetEagerTensor(output, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    TF_Tensor* resolved = TFE_TensorHandleResolve(eager_handle, st);
    ASSERT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
    values[i] = *static_cast<float*>(TF_TensorData(resolved));
    TF_DeleteTensor(resolved);
  }
  ASSERT_EQ(values[0], 5.0);
  ASSERT_EQ(values[1], 6.0);

  // Release the output tensors, then the list, context and function.
  for (int i = 0; i < 2; ++i) {
    TF_AbstractTensor* output = TF_OutputListGet(call_outputs, i);
    TF_DeleteAbstractTensor(output);
  }
  TF_DeleteOutputList(call_outputs);
  TF_DeleteExecutionContext(eager_exec_ctx);
  TF_DeleteAbstractFunction(traced_fn);
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -193,18 +322,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
|
||||
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
|
||||
ASSERT_EQ(nullptr, func);
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -222,10 +348,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -243,7 +369,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -273,7 +399,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
@ -289,10 +416,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
@ -349,5 +477,7 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -101,6 +101,9 @@ class AbstractContextInterface {
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -44,6 +44,7 @@ tf_cc_test(
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -53,3 +54,19 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -92,6 +92,10 @@ class ParallelDevice {
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
status);
|
||||
}
|
||||
|
||||
// Builds a parallel tensor with one scalar int64 component per underlying
// device; the component placed on device `i` holds the value `i`.
//
// Each component is produced by running a "Const" op on the corresponding
// device. If any eager C API call leaves `status` non-OK, nullptr is returned
// immediately. Otherwise the result of ParallelTensor::FromTensorHandles is
// returned.
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    TFE_Context* context, TF_Status* status) const {
  // TODO(allenl): We could cache DeviceIDs (keyed by context).
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (size_t device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    // Heap-allocated payload handed to TF_NewTensor together with a
    // deallocator that deletes it.
    int64_t* device_id = new int64_t(static_cast<int64_t>(device_index));
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int64_t),
            [](void* data, size_t, void*) {
              delete static_cast<int64_t*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    OpPtr const_op(TFE_NewOp(context, "Const", status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
    TFE_TensorHandle* device_handle = nullptr;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    // `components` now owns the handle; no call after TFE_Execute touches
    // `status`, so it does not need to be re-checked here.
    components.emplace_back(device_handle);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
|
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
// Registers an op named "DeviceID". The registration declares no inputs and a
// single int64 output called `device_id`, marks the op as stateful, and uses
// shape_inference::ScalarShape as its shape function.
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
|
@ -31,9 +31,6 @@ cc_library(
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
|
||||
# so that we can depend on c/eager/c_api_unified_experimental.h.
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
@ -41,6 +38,8 @@ cc_library(
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
":tensorhandle_list",
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
@ -160,6 +159,38 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list",
|
||||
srcs = [
|
||||
"tensorhandle_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_list_type",
|
||||
hdrs = [
|
||||
"tensorhandle_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
|
||||
// internal header, and implement this function.
|
||||
return nullptr;
|
||||
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
|
||||
return tensorflow::unwrap(list)->size();
|
||||
}
|
||||
|
||||
// Returns the `i`th handle stored in `list`.
//
// `i` is a signed int in the C API, so it is validated before being used as a
// vector index: a negative value or one that is not smaller than the list
// size yields nullptr instead of reading outside the underlying vector.
// `list` itself is dereferenced unconditionally and must not be null.
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
                                         int i) {
  const auto* handles = tensorflow::unwrap(list);
  // Check the sign first so the cast to size_t below cannot wrap around.
  if (i < 0 || static_cast<size_t>(i) >= handles->size()) {
    return nullptr;
  }
  return tensorflow::wrap((*handles)[static_cast<size_t>(i)]);
}
|
||||
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// NOTE: the include guard is named after this header, which is included
// elsewhere as ".../saved_model/internal/tensorhandle_list_type.h". It used to
// carry the CONCRETE_FUNCTION_LIST_TYPE name; two headers sharing one guard
// macro means whichever is included second is silently skipped.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_

#include <vector>

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"

// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.

// Opaque C handle type; the C++ type paired with it is named in the
// DEFINE_CONVERSION_FUNCTIONS invocation below.
typedef struct TF_TensorHandleList TF_TensorHandleList;

namespace tensorflow {

// Pairs std::vector<AbstractTensorHandleInterface*> with TF_TensorHandleList
// so that tensorflow::wrap() / tensorflow::unwrap() can convert between the
// two pointer types.
DEFINE_CONVERSION_FUNCTIONS(
    std::vector<tensorflow::AbstractTensorHandleInterface*>,
    TF_TensorHandleList)

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
|
@ -24,6 +24,7 @@ exports_files(
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tensorhandle_list.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
@ -39,6 +40,7 @@ cc_library(
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
":tensorhandle_list",
|
||||
],
|
||||
)
|
||||
|
||||
@ -61,3 +63,8 @@ alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
||||
|
||||
# Public-package name for the tensorhandle_list target that lives in the
# internal package.
alias(
    name = "tensorhandle_list",
    actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
@ -17,9 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
|
@ -21,19 +21,27 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT size_t
|
||||
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
|
||||
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
|
||||
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -0,0 +1,43 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that acts like a list of TFE_TensorHandle pointers.
|
||||
typedef struct TF_TensorHandleList TF_TensorHandleList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
|
||||
const TF_TensorHandleList* list);
|
||||
|
||||
// Returns the `i`th TFE_TensorHandle in the list.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
|
||||
const TF_TensorHandleList* list, int i);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
|
@ -62,3 +62,17 @@ cc_library(
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
# Header-only target (hdrs, no srcs) exposing tensorhandle.h together with the
# runtime/status/tensor wrappers and the eager C API targets it depends on.
cc_library(
    name = "tensorhandle",
    hdrs = ["tensorhandle.h"],
    deps = [
        ":runtime",
        ":status",
        ":tensor",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
    ],
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
@ -40,6 +41,7 @@ class Runtime {
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
@ -63,6 +65,7 @@ class Runtime {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
@ -57,6 +58,7 @@ class Status {
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||
// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
|
||||
// to tensors placed in memory of different devices or remote address spaces.
|
||||
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||
// from it.
|
||||
class TensorHandle {
|
||||
public:
|
||||
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||
// status->ok() will be false, and the returned Tensor must not be used.
|
||||
Tensor Resolve(Status* status);
|
||||
|
||||
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||
Status* status);
|
||||
|
||||
// TensorHandle is movable, and not copyable
|
||||
TensorHandle(TensorHandle&&) = default;
|
||||
TensorHandle& operator=(TensorHandle&&) = default;
|
||||
|
||||
private:
|
||||
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||
|
||||
// TensorHandle is not copyable
|
||||
TensorHandle(const TensorHandle&) = delete;
|
||||
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||
|
||||
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||
|
||||
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||
// and takes ownership of handle.
|
||||
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||
|
||||
struct TFETensorHandleDeleter {
|
||||
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||
};
|
||||
|
||||
// Asks the eager C API to resolve the wrapped handle and wraps the outcome in
// a Tensor. When `status` reports an error, the Tensor is built around nullptr
// rather than around whatever pointer the C call returned.
inline Tensor TensorHandle::Resolve(Status* status) {
  TF_Tensor* resolved =
      TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
  return Tensor(status->ok() ? resolved : nullptr);
}
|
||||
|
||||
// Creates a TFE_TensorHandle for `tensor` in the context owned by `runtime`
// and wraps it. When `status` reports an error, the returned TensorHandle
// wraps nullptr rather than whatever pointer the C call returned.
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
                                             const Runtime& runtime,
                                             Status* status) {
  TFE_TensorHandle* created = TFE_NewTensorHandleFromTensor(
      runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
  return TensorHandle(status->ok() ? created : nullptr);
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
@ -5,12 +5,22 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Test-only, header-only helper exposing tensor_types_test_util.h.
cc_library(
    name = "tensor_types_test_util",
    testonly = True,
    hdrs = [
        "tensor_types_test_util.h",
    ],
    deps = ["//tensorflow/c:tf_datatype"],
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
@ -19,3 +29,22 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Test target for the C++ TensorHandle wrapper, built from
# tensorhandle_test.cc.
tf_cc_test(
    name = "tensorhandle_test",
    srcs = ["tensorhandle_test.cc"],
    deps = [
        ":tensor_types_test_util",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/cc/experimental/base/public:runtime",
        "//tensorflow/cc/experimental/base/public:runtime_builder",
        "//tensorflow/cc/experimental/base/public:status",
        "//tensorflow/cc/experimental/base/public:tensor",
        "//tensorflow/cc/experimental/base/public:tensorhandle",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
|
||||
|
@ -16,69 +16,22 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
using SimpleTypes =
|
||||
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
|
||||
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
@ -88,14 +41,13 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
cc::Status status;
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
cc::Tensor tensor = cc::Tensor::FromBuffer(
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
@ -185,22 +137,22 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
cc::Status status;
|
||||
Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
cc::Tensor tensor =
|
||||
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
Tensor tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
}
|
||||
// At this point, tensor has been destroyed, and the deleter callback should
|
||||
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// The guard name mirrors the include path of this header
// (tensorflow/cc/experimental/base/tests/tensor_types_test_util.h).
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_

#include <stdint.h>

#include "tensorflow/c/tf_datatype.h"

namespace tensorflow {

// Each of the following struct types has two members: a kDType that
// corresponds to a TF_DataType enum value, and an alias "type" naming its
// corresponding C++ type. These types allow us to write dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
//
// NOTE: kDType is declared in-class only. Before C++17 a static constexpr
// member that is bound to a reference also needs an out-of-class definition,
// so copy it into a local variable before handing it to anything that takes
// its argument by reference.
struct FloatType {
  using type = float;
  static constexpr TF_DataType kDType = TF_FLOAT;
};

struct DoubleType {
  using type = double;
  static constexpr TF_DataType kDType = TF_DOUBLE;
};

struct Int32Type {
  using type = int32_t;
  static constexpr TF_DataType kDType = TF_INT32;
};

struct UINT8Type {
  using type = uint8_t;
  static constexpr TF_DataType kDType = TF_UINT8;
};

struct INT8Type {
  using type = int8_t;
  static constexpr TF_DataType kDType = TF_INT8;
};

struct INT64Type {
  using type = int64_t;
  static constexpr TF_DataType kDType = TF_INT64;
};

struct UINT16Type {
  using type = uint16_t;
  static constexpr TF_DataType kDType = TF_UINT16;
};

struct UINT32Type {
  using type = uint32_t;
  static constexpr TF_DataType kDType = TF_UINT32;
};

struct UINT64Type {
  using type = uint64_t;
  static constexpr TF_DataType kDType = TF_UINT64;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
using tensorflow::experimental::cc::TensorHandle;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||
// Round-trips a scalar of the parameterised dtype: Tensor -> TensorHandle ->
// Tensor, then checks rank, dtype, value, byte size and element count of the
// resolved tensor.
TYPED_TEST(ConstructScalarTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  using ValueType = typename TypeParam::type;

  Status status;
  RuntimeBuilder runtime_builder;
  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  const TF_DataType dtype = TypeParam::kDType;
  ValueType value = 42;
  // An empty shape vector describes a rank-0 (scalar) tensor.
  const std::vector<int64_t> scalar_shape;
  // `value` is a local variable, so there is nothing for the deleter to free.
  const auto noop_deleter = [](void*, size_t) {};

  Tensor original_tensor =
      Tensor::FromBuffer(dtype, scalar_shape, &value, sizeof(value),
                         noop_deleter, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  TensorHandle handle =
      TensorHandle::FromTensor(original_tensor, *runtime, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  Tensor tensor = handle.Resolve(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  EXPECT_EQ(tensor.dims(), 0);
  EXPECT_EQ(tensor.dtype(), dtype);
  EXPECT_EQ(*reinterpret_cast<ValueType*>(tensor.data()), 42);
  EXPECT_EQ(tensor.num_bytes(), sizeof(ValueType));
  EXPECT_EQ(tensor.num_elements(), 1);
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
// Round-trips a six-element rank-1 tensor of the parameterised dtype: Tensor
// -> TensorHandle -> Tensor, then checks rank, dtype, every element, byte
// size and element count of the resolved tensor.
TYPED_TEST(Construct1DTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  using ValueType = typename TypeParam::type;

  Status status;
  RuntimeBuilder runtime_builder;
  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  const TF_DataType dtype = TypeParam::kDType;
  // Backing storage for the rank-1 tensor.
  std::vector<ValueType> value = {42, 100, 0, 1, 4, 29};
  // A single dimension whose extent is the number of elements.
  const std::vector<int64_t> shape = {static_cast<int64_t>(value.size())};
  const size_t byte_count = value.size() * sizeof(ValueType);
  // `value` owns the buffer, so there is nothing for the deleter to free.
  const auto noop_deleter = [](void*, size_t) {};

  Tensor original_tensor = Tensor::FromBuffer(
      dtype, shape, value.data(), byte_count, noop_deleter, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  TensorHandle handle =
      TensorHandle::FromTensor(original_tensor, *runtime, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  Tensor tensor = handle.Resolve(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  EXPECT_EQ(tensor.dims(), 1);
  EXPECT_EQ(tensor.dtype(), dtype);

  tensorflow::gtl::ArraySlice<ValueType> tensor_view(
      reinterpret_cast<ValueType*>(tensor.data()), value.size());
  // Expected values are spelled out as literals instead of being read back
  // from `value`, so the check does not depend on the input buffer's contents.
  const ValueType expected[] = {42, 100, 0, 1, 4, 29};
  for (size_t idx = 0; idx < 6; ++idx) {
    EXPECT_EQ(tensor_view[idx], expected[idx]);
  }

  EXPECT_EQ(tensor.num_bytes(), byte_count);
  EXPECT_EQ(tensor.num_elements(), value.size());
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -84,7 +84,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
]) + if_android([
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
@ -40,6 +41,7 @@ class FunctionMetadata final {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
@ -26,10 +26,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::SavedModelAPI;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -42,7 +42,8 @@ def tf_library(
|
||||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
|
||||
math enabled on cpu.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
@ -207,6 +208,15 @@ def tf_library(
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
default_fast_math_xla_flags = ("XLA_FLAGS='" +
|
||||
"--xla_cpu_enable_fast_math=true " +
|
||||
"--xla_cpu_fast_math_honor_nans=false " +
|
||||
"--xla_cpu_fast_math_honor_infs=false " +
|
||||
"--xla_cpu_fast_math_honor_functions=false " +
|
||||
"--xla_cpu_fast_math_honor_division=false " +
|
||||
"--xla_cpu_enable_fast_min_max=true " +
|
||||
"$${XLA_FLAGS:-}' ")
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = srcs,
|
||||
@ -216,6 +226,7 @@ def tf_library(
|
||||
function_object_file,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
@ -256,6 +267,7 @@ def tf_library(
|
||||
session_module_pb,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
@ -67,6 +67,8 @@ int main(int argc, char** argv) {
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
|
||||
// generally the preferred case.
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
@ -251,7 +251,7 @@ cc_library(
|
||||
visibility = [":friends"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:graph",
|
||||
|
@ -77,10 +77,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
|
||||
"//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt",
|
||||
"//tensorflow/compiler/mlir/tfrt:tf_to_corert",
|
||||
],
|
||||
)
|
||||
|
||||
@ -152,7 +148,6 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/tfrt:compatibility_analysis",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -31,7 +31,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
std::string node_def_str;
|
||||
if (!node_def.SerializeToString(&node_def_str)) {
|
||||
return emitError(loc, "failed to serialize tensorflow node_def"),
|
||||
llvm::None;
|
||||
}
|
||||
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
|
||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
||||
}
|
||||
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
||||
size_t map_start = flex_builder->StartMap();
|
||||
for (const auto& pair : node_def.attr()) {
|
||||
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
|
||||
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
|
||||
std::sort(attrs.begin(), attrs.end(),
|
||||
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
|
||||
for (const Item& pair : attrs) {
|
||||
const char* key = pair.first.c_str();
|
||||
const auto& attr = pair.second;
|
||||
const ::tensorflow::AttrValue& attr = pair.second;
|
||||
switch (attr.value_case()) {
|
||||
case ::tensorflow::AttrValue::kS:
|
||||
flex_builder->String(key, attr.s());
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
|
||||
|
||||
@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
|
||||
TFL_TensorOfOrNone<[F32, I32]>:$bias,
|
||||
TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
|
||||
I32Attr:$dilation_h_factor,
|
||||
I32Attr:$dilation_w_factor,
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
I32Attr:$stride_w
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasOptions = 0b1;
|
||||
}
|
||||
@ -1561,10 +1561,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
||||
}
|
||||
|
||||
def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType,
|
||||
TFL_GpuTargetOp]> {
|
||||
def TFL_HardSwishOp: TFL_Op<"hard_swish", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Hardswish activation function.";
|
||||
let description = [{
|
||||
Computes hard-swish activation function
|
||||
@ -1574,7 +1576,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output);
|
||||
|
||||
let hasOptions = 0;
|
||||
}
|
||||
@ -1606,7 +1608,8 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
|
||||
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
|
||||
SameOperandsAndResultShape,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
let summary = "Leaky Relu operator";
|
||||
|
||||
let description = [{
|
||||
@ -1740,7 +1743,8 @@ def TFL_LogOp: TFL_Op<"log", [
|
||||
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType,
|
||||
PredOpTrait<"x and y must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
// zero_point = max_value
|
||||
// scale = -log_softmax_output_min / (max_value + 1)
|
||||
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
|
||||
@ -1896,11 +1900,11 @@ Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32]>:$x
|
||||
TFL_FpTensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32]>:$y
|
||||
TFL_FpTensor:$y
|
||||
);
|
||||
}
|
||||
|
||||
@ -2443,9 +2447,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
|
||||
Computes element-wise reverse square root of input
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$x);
|
||||
let arguments = (ins TFL_FpTensor:$x);
|
||||
|
||||
let results = (outs AnyTensor:$y);
|
||||
let results = (outs TFL_FpTensor:$y);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
@ -3361,9 +3365,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
NoQuantizableResult]> {
|
||||
def TFL_DensifyOp: TFL_Op<"densify", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Densify operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace lite {
|
||||
@ -38,6 +39,7 @@ namespace lite {
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
|
||||
// Apply quantization passes
|
||||
PassManager pm(module->getContext());
|
||||
TFL::QuantizationSpecs quant_specs;
|
||||
quant_specs.inference_type = tensorflow::DT_QINT8;
|
||||
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
|
||||
quant_specs.post_training_quantization = true;
|
||||
quant_specs.disable_per_channel = disable_per_channel;
|
||||
|
||||
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
|
||||
auto input_tf_type = tflite::TflTypeToTfType(input_type);
|
||||
if (input_tf_type == tensorflow::DT_FLOAT) {
|
||||
emit_adaptor = true;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8) {
|
||||
quant_specs.inference_type = tensorflow::DT_QUINT8;
|
||||
} else if (input_tf_type == tensorflow::DT_UINT8 ||
|
||||
input_tf_type == tensorflow::DT_INT8 ||
|
||||
input_tf_type == tensorflow::DT_INT16) {
|
||||
quant_specs.inference_type = input_tf_type;
|
||||
}
|
||||
|
||||
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
|
||||
|
@ -26,11 +26,13 @@ namespace mlir {
|
||||
namespace lite {
|
||||
|
||||
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
|
||||
// The `input_type` and `output_type` can be float32/qint8/int8.
|
||||
// The `input_type`, `output_type` and `inference_type` can be
|
||||
// float32/qint8/int8/int16.
|
||||
// Return partially quantized model if `fully_quantize` is false.
|
||||
TfLiteStatus QuantizeModel(
|
||||
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
|
||||
const tflite::TensorType& output_type,
|
||||
const tflite::TensorType& inference_type,
|
||||
const std::unordered_set<std::string>& operator_names,
|
||||
bool disable_per_channel, bool fully_quantize,
|
||||
flatbuffers::FlatBufferBuilder* builder,
|
||||
|
@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
|
||||
|
||||
tflite::StderrReporter error_reporter;
|
||||
return mlir::lite::QuantizeModel(
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
|
||||
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
|
||||
tflite::TensorType_INT8, {},
|
||||
/*disable_per_channel=*/false,
|
||||
/*fully_quantize=*/true, builder, &error_reporter);
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ struct QuantizationSpecs {
|
||||
bool RunWeightQuantization() const { return weight_quantization; }
|
||||
|
||||
// Whether this inference type represents a signed storage type.
|
||||
bool IsSignedInferenceType() {
|
||||
bool IsSignedInferenceType() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QUINT8:
|
||||
case tensorflow::DT_QUINT16:
|
||||
@ -102,7 +102,7 @@ struct QuantizationSpecs {
|
||||
|
||||
// Gets the width of this quantization type. Returns 0 if it isn't a
|
||||
// quantization type.
|
||||
int64_t GetQuantizationTypeWidth() {
|
||||
int64_t GetQuantizationTypeWidth() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QINT8:
|
||||
case tensorflow::DT_QUINT8:
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
@ -35,6 +36,7 @@ limitations under the License.
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
||||
}
|
||||
};
|
||||
|
||||
// Fold Extra Requantize ops if the preceding ops has free scale requirement.
|
||||
template <typename RQ>
|
||||
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
|
||||
explicit FoldTrivalRequantizeOp(MLIRContext* context)
|
||||
: OpRewritePattern<RQ>(context, 1) {}
|
||||
|
||||
LogicalResult matchAndRewrite(RQ op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Value pre_quantized = op.input();
|
||||
auto pre_quantized_type =
|
||||
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
|
||||
if (!pre_quantized_type) return failure();
|
||||
|
||||
Operation* def = pre_quantized.getDefiningOp();
|
||||
if (!def) return failure();
|
||||
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
|
||||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
|
||||
|
||||
llvm::SmallVector<Type, 4> new_output_types;
|
||||
for (auto result : def->getResults()) {
|
||||
result.getUsers().begin()->dump();
|
||||
op.dump();
|
||||
if (result.hasOneUse() && *result.getUsers().begin() == op) {
|
||||
new_output_types.push_back(op.qtype());
|
||||
} else {
|
||||
new_output_types.push_back(result.getType());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove this rescale op.
|
||||
rewriter.replaceOp(op, {pre_quantized});
|
||||
|
||||
// Replace the output scale of the preceding op.
|
||||
rewriter.setInsertionPointAfter(def);
|
||||
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
|
||||
def->getOperands(), new_output_types,
|
||||
def->getAttrs());
|
||||
Operation* new_op = rewriter.createOperation(new_state);
|
||||
|
||||
rewriter.replaceOp(def, new_op->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Given a quantized type `input`, magnifying its scales by the factor stored in
|
||||
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
|
||||
// dimension size of `input` or isn't floating-point, nullptr will be returned.
|
||||
|
@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
|
||||
return %1 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacent
|
||||
// CHECK: %cst = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return
|
||||
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %[[RESHAPE]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output has more than one
|
||||
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
return %3 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
|
||||
// CHECK: %cst = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %2 = addf %0, %1
|
||||
// CHECK: return %2
|
||||
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be kept if its output has more than one
|
||||
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
|
||||
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
|
||||
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
|
||||
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %0, %1
|
||||
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
|
||||
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output type is the same
|
||||
|
@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<9> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<7> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_int
|
||||
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_same_shape
|
||||
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
|
||||
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
|
||||
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
return %0 : tensor<2x2xi32>
|
||||
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_splat_float
|
||||
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_float
|
||||
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_same_shape
|
||||
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
|
||||
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
|
||||
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rank
|
||||
func @rank() -> tensor<1xi32> {
|
||||
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rank_input_known_rank
|
||||
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
|
||||
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
|
||||
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
%shape = constant dense<[4]> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
|
||||
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
%shape = constant dense<[4]> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
|
||||
|
||||
// CHECK-LABEL: @pseudo_const
|
||||
func @pseudo_const() -> tensor<i32> {
|
||||
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
|
||||
%cst_1 = constant dense<4> : tensor<i32>
|
||||
%cst_2 = constant dense<1> : tensor<i32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<4.0> : tensor<f32>
|
||||
%cst_2 = constant dense<1.0> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<-4.0> : tensor<f32>
|
||||
%cst_2 = constant dense<-1.0> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<7.0> : tensor<f32>
|
||||
%cst_2 = constant dense<1.5> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
|
||||
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst_perm = constant dense<0> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
|
||||
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst_perm = constant dense<0> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
|
||||
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
|
||||
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
|
||||
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
|
||||
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
|
||||
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
|
||||
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
|
||||
return %0 : tensor<4x2x3xi32>
|
||||
}
|
||||
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
|
||||
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %87 : tensor<?xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
|
||||
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
|
||||
|
||||
return %2 : tensor<?xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concat_2_tensors_1_empty
|
||||
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
|
||||
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
|
||||
return %3 : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
|
||||
// CHECK: return [[cst]] : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
|
||||
// CHECK: return %[[CST]] : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concat_3_tensors_1_empty
|
||||
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
|
||||
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
|
||||
return %3 : tensor<?xi32>
|
||||
|
||||
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
|
||||
return %0 : tensor<2x2x3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concatConstantTensorsMiddleDim
|
||||
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
|
||||
return %0 : tensor<1x4x3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concatConstantTensorsLastDim
|
||||
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
|
||||
return %0 : tensor<1x2x6xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
|
||||
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_different_rank
|
||||
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
|
||||
|
||||
return %0 : tensor<1x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
// Legalization test: tf.ConcatV2 whose axis operand is an i64 scalar constant (-1).
// The expected tfl.concatenation carries the same axis value as an i32 attribute.
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2I64Axis
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 2, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 3 ],
|
||||
|
@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
|
||||
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RemoveTrival
|
||||
// A tfl.fully_connected result (scale 1.0) feeds a tfl.quantize that only changes the
// scale to 2.0. The expected output has no separate quantize op: the fully_connected
// itself yields the scale-2.0 type and is returned directly.
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
|
||||
%1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
|
||||
%2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
|
||||
return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
|
||||
|
||||
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
|
||||
// CHECK-NEXT: return %[[fc]]
|
||||
}
|
||||
|
||||
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
%cst = constant dense<[1, 1001]> : tensor<2xi32>
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
|
||||
|
@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
quant_specs.default_ranges.second.hasValue()) {
|
||||
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0)));
|
||||
quant_specs.default_ranges.second.getValueOr(0.0),
|
||||
quant_specs.IsSignedInferenceType()));
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
|
@ -46,8 +46,11 @@ namespace {
|
||||
class DefaultQuantParamsPass
|
||||
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
|
||||
public:
|
||||
explicit DefaultQuantParamsPass(double default_min, double default_max)
|
||||
: default_min_(default_min), default_max_(default_max) {}
|
||||
explicit DefaultQuantParamsPass(double default_min, double default_max,
|
||||
bool is_signed)
|
||||
: default_min_(default_min),
|
||||
default_max_(default_max),
|
||||
is_signed_(is_signed) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
|
||||
|
||||
double default_min_;
|
||||
double default_max_;
|
||||
bool is_signed_;
|
||||
quant::QuantParams default_quant_params_;
|
||||
};
|
||||
} // namespace
|
||||
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
|
||||
default_quant_params_ = quant::fakeQuantAttrsToType(
|
||||
builder.getUnknownLoc(),
|
||||
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
|
||||
builder.getF32Type());
|
||||
builder.getF32Type(), is_signed_);
|
||||
}
|
||||
return default_quant_params_;
|
||||
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
||||
double default_min, double default_max, bool is_signed) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
|
||||
is_signed);
|
||||
}
|
||||
|
||||
// Registers this pass with default values, only for test
|
||||
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
|
||||
"tfl-default-quant",
|
||||
"Apply quantization with default quantization parameter", [] {
|
||||
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
|
||||
/*default_max=*/1.0);
|
||||
/*default_max=*/1.0,
|
||||
/*is_signed=*/false);
|
||||
});
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -321,7 +321,8 @@ void DenseToSparse::runOnFunction() {
|
||||
|
||||
if (result.needs_densify) {
|
||||
const auto value = op->getOperand(operand);
|
||||
auto densify = builder.create<DensifyOp>(op->getLoc(), value);
|
||||
auto densify =
|
||||
builder.create<DensifyOp>(op->getLoc(), value.getType(), value);
|
||||
value.replaceAllUsesWith(densify);
|
||||
densify.setOperand(value);
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
||||
// The value won't change in the new attribute, but if the value is out of
|
||||
// the bound of i32, the function returns a failure.
|
||||
// On success, *attr_i32 holds the result; on failure it is left unmodified.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
|
||||
// Fast path: the attribute already has a 32-bit integer type, so pass it through as-is.
if (attr.getType().isInteger(/*width=*/32)) {
|
||||
*attr_i32 = attr;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Hold the value in a 64-bit integer so the range check below is done before any narrowing.
int64_t value = attr.getInt();
|
||||
// NOTE(review): the bounds used are those of `int`, which is presumably 32 bits
// on every supported target - confirm.
if (value > std::numeric_limits<int>::max() ||
|
||||
value < std::numeric_limits<int>::min()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Rebuild the attribute with the same value but a 32-bit integer type.
*attr_i32 = IntegerAttr::get(
|
||||
IntegerType::get(/*width=*/32, attr.getContext()), value);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
// Extract axis attribute from constant axis tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
|
||||
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
|
||||
|
||||
// "axis" operand could be a i64 tensor. Resolve it here.
|
||||
IntegerAttr axis_i32;
|
||||
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<ConcatenationOp>(
|
||||
op, output_type, values, ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
op, output_type, values, axis_i32, fused_activation_function);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -859,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
target.addLegalOp<ConstantOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
target.addLegalOp<ReturnOp>();
|
||||
target.addLegalOp<TFL::CustomOp>();
|
||||
// Register fused LSTM/RNN ops as legal.
|
||||
target.addLegalOp<TFL::LSTMOp>();
|
||||
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();
|
||||
|
@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
|
||||
// Creates an instance of the TensorFlow Lite dialect pass to add default
|
||||
// quantization parameters.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max);
|
||||
double default_min, double default_max, bool is_signed);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
|
||||
// tensor to sparse format.
|
||||
|
@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
if (!emit_quant_adaptor_ops_) {
|
||||
|
@ -70,6 +70,7 @@ class PrepareQuantizePass
|
||||
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
|
||||
public:
|
||||
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
||||
// This is only used by test.
|
||||
explicit PrepareQuantizePass() {
|
||||
if (quantize_signed)
|
||||
quant_specs_.inference_type = tensorflow::DT_QINT8;
|
||||
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// convert all of them to signed.
|
||||
OwningRewritePatternList patterns;
|
||||
bool is_signed = quant_specs_.IsSignedInferenceType();
|
||||
int bit_width = quant_specs_.GetQuantizationTypeWidth();
|
||||
if (is_signed) {
|
||||
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
|
||||
// Convert quant stats to int8 quantization parameters.
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
|
||||
} else {
|
||||
// Convert quant stats to uint8 quantization parameters.
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
||||
}
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
|
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
# Python extension module "mlir_wrapper": pybind11 bindings built from the
# attrs/basic_classes/builders/ops/types sources plus the module entry point
# (mlir_wrapper.cc) and their shared header. Links against the TensorFlow MLIR
# dialect, MLIR IR/StandardOps and LLVM support. Publicly visible.
tf_python_pybind_extension(
|
||||
name = "mlir_wrapper",
|
||||
srcs = [
|
||||
"attrs.cc",
|
||||
"basic_classes.cc",
|
||||
"builders.cc",
|
||||
"mlir_wrapper.cc",
|
||||
"mlir_wrapper.h",
|
||||
"ops.cc",
|
||||
"types.cc",
|
||||
],
|
||||
module_name = "mlir_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
# Python extension module "filecheck_wrapper": a single-source pybind11 module
# (filecheck_wrapper.cc) that depends only on LLVM support and the pybind11
# helper libraries. Publicly visible.
tf_python_pybind_extension(
|
||||
name = "filecheck_wrapper",
|
||||
srcs = ["filecheck_wrapper.cc"],
|
||||
module_name = "filecheck_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
// Registers the Python bindings for MLIR attribute classes on module `m`.
void init_attrs(py::module& m) {
  // The base class is registered first so that IntegerAttr can name it as its
  // parent below.
  py::class_<mlir::Attribute>(m, "Attribute");

  // IntegerAttr.get(type, value) forwards to the (Type, int64_t) overload of
  // mlir::IntegerAttr::get.
  auto make_integer_attr =
      py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get);
  py::class_<mlir::IntegerAttr, mlir::Attribute> integer_attr_cls(
      m, "IntegerAttr");
  integer_attr_cls.def("get", make_integer_attr);
}
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
// Registers the Python bindings for the basic MLIR IR classes on module `m`:
// MLIRContext, Location/UnknownLoc, Region, Block (and its iterator), and
// Value with its OpResult/BlockArgument subclasses.
void init_basic_classes(py::module& m) {
  py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());

  py::class_<mlir::Location>(m, "Location");

  py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
      .def("get", &mlir::UnknownLoc::get);

  // `back`/`front` use the `reference` policy, so Python does not take
  // ownership of the returned block. "front" is registered exactly once; the
  // original registered an identical second overload that could never be
  // selected.
  py::class_<mlir::Region>(m, "Region")
      .def("back", &mlir::Region::back, py::return_value_policy::reference)
      .def("front", &mlir::Region::front, py::return_value_policy::reference)
      .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
      .def("push_back", &mlir::Region::push_back)
      .def("size", [](mlir::Region& r) { return r.getBlocks().size(); });

  // Opaque handle type; needed so Block.end() and the OpBuilder bindings that
  // take a Block::iterator have a registered Python type.
  py::class_<mlir::Block::iterator>(m, "Block_Iterator");

  // NOTE(review): "new" heap-allocates a Block and returns it with the
  // `reference` policy, so Python never frees it. Presumably callers are
  // expected to hand the block to Region.push_back -- confirm.
  py::class_<mlir::Block>(m, "Block")
      .def("new", ([]() { return new mlir::Block; }),
           py::return_value_policy::reference)
      .def("end", &mlir::Block::end)
      .def("addArgument", &mlir::Block::addArgument);

  py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
  py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
  py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
}
|
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
// Registers the Python bindings for mlir::Builder, mlir::OpBuilder and
// mlir::OpBuilder::InsertPoint on module `m`.
void init_builders(py::module& m) {
  // ---- mlir::Builder ----
  py::class_<mlir::Builder> builder_cls(m, "Builder");
  builder_cls.def(py::init<mlir::MLIRContext*>());
  // Python lists arrive as std::vector; MLIR wants ArrayRef views over them.
  builder_cls.def(
      "getFunctionType",
      [](mlir::Builder& self, std::vector<mlir::Type> arg_types,
         std::vector<mlir::Type> result_types) {
        llvm::ArrayRef<mlir::Type> ins(arg_types);
        llvm::ArrayRef<mlir::Type> outs(result_types);
        return self.getFunctionType(ins, outs);
      });

  // ---- mlir::OpBuilder ----
  py::class_<mlir::OpBuilder> op_builder_cls(m, "OpBuilder");

  // Constructors, registered in the same order as before.
  op_builder_cls.def(py::init<mlir::MLIRContext*>());
  op_builder_cls.def(py::init<mlir::Region&>());
  op_builder_cls.def(py::init<mlir::Operation*>());
  op_builder_cls.def(py::init<mlir::Block*, mlir::Block::iterator>());

  op_builder_cls.def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc);

  // Only the (Block*, iterator) overload of setInsertionPoint is exposed.
  op_builder_cls.def("setInsertionPoint",
                     py::overload_cast<mlir::Block*, mlir::Block::iterator>(
                         &mlir::OpBuilder::setInsertionPoint));
  op_builder_cls.def("saveInsertionPoint",
                     &mlir::OpBuilder::saveInsertionPoint);
  op_builder_cls.def("restoreInsertionPoint",
                     &mlir::OpBuilder::restoreInsertionPoint);

  // The created operation and the context are returned by reference: Python
  // does not take ownership of either.
  op_builder_cls.def(
      "createOperation",
      [](mlir::OpBuilder& self, mlir::OperationState& op_state) {
        return self.createOperation(op_state);
      },
      py::return_value_policy::reference);
  op_builder_cls.def("getContext", &mlir::OpBuilder::getContext,
                     py::return_value_policy::reference);

  // ---- mlir::OpBuilder::InsertPoint ----
  py::class_<mlir::OpBuilder::InsertPoint> insert_point_cls(
      m, "OpBuilder_InsertionPoint");
  insert_point_cls.def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
}
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
// Python module exposing a single function:
//
//   check(input: str, check: str) -> bool
//
// Runs LLVM FileCheck with a default FileCheckRequest, treating `check` as the
// check file (the text holding the CHECK directives) and `input` as the text
// to verify. Returns True only if the directives were read successfully and
// the input satisfies them.
PYBIND11_MODULE(filecheck_wrapper, m) {
  m.def("check", [](std::string input, std::string check) {
    llvm::FileCheckRequest fcr;
    llvm::FileCheck fc(fcr);
    llvm::SourceMgr SM = llvm::SourceMgr();
    // Both buffers are registered with the source manager so diagnostics can
    // be reported against them. The buffers do not copy the strings; `input`
    // and `check` outlive `SM` within this lambda.
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                          llvm::SMLoc());
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
                          llvm::SMLoc());
    llvm::Regex regex = fc.buildCheckPrefixRegex();
    // readCheckFile returns true on error (e.g. malformed or missing
    // directives). Previously the result was ignored and checkInput ran
    // against whatever had been parsed, which could report success for a
    // check string that was never actually applied.
    if (fc.readCheckFile(SM, llvm::StringRef(check), regex)) {
      return false;
    }
    return fc.checkInput(SM, llvm::StringRef(input));
  });
}
|
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
// Entry point of the `mlir_wrapper` Python module.
//
// Exposes registerDialects(), which registers the TensorFlow, TensorFlow
// executor and Standard dialects, and then installs the class bindings
// provided by the init_* helpers.
PYBIND11_MODULE(mlir_wrapper, m) {
  auto register_dialects = []() {
    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
  };
  m.def("registerDialects", register_dialects);

  // Binding groups are installed in the same order as before.
  init_basic_classes(m);
  init_types(m);
  init_builders(m);
  init_ops(m);
  init_attrs(m);
}
|
@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
//===- dialect_static_registration.cc -------------------------------------===//
|
||||
//
|
||||
// This file registers the RuntimeFallbackDialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tfd {
|
||||
namespace py = pybind11;
|
||||
|
||||
// Static initialization for dialect registration.
|
||||
static DialectRegistration<RuntimeFallbackDialect> tfd_registration;
|
||||
void init_basic_classes(py::module& m);
|
||||
void init_types(py::module& m);
|
||||
void init_builders(py::module& m);
|
||||
void init_ops(py::module& m);
|
||||
void init_attrs(py::module& m);
|
||||
|
||||
} // namespace tfd
|
||||
} // namespace mlir
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
@ -0,0 +1,194 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
// Registers the Python bindings for MLIR operations on module `m`:
// the generic Operation / OperationState classes, ModuleOp, FuncOp, ReturnOp
// and a set of TensorFlow dialect ops.
//
// Every op class exposes a `create` function that takes an OpBuilder and a
// Location followed by the op's operands/attributes, builds the op through
// the builder and returns the underlying mlir::Operation*.
void init_ops(py::module& m) {
  // The `nodelete` holder means Python never destroys an Operation it holds.
  py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
      m, "Operation")
      .def("getRegion", &mlir::Operation::getRegion,
           py::return_value_policy::reference)
      .def("getResult", &mlir::Operation::getResult)
      .def("dump", &mlir::Operation::dump)
      .def("getNumResults", &mlir::Operation::getNumResults);

  py::class_<mlir::OperationState>(m, "OperationState")
      .def(py::init([](mlir::Location loc, std::string name) {
        return mlir::OperationState(loc, llvm::StringRef(name));
      }))
      .def("addTypes",
           [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
             state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
           })
      .def("addOperands",
           [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
             os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
           })
      .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
           py::return_value_policy::reference);

  py::class_<mlir::ModuleOp>(m, "ModuleOp")
      .def("create",
           [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
      .def("push_back",
           [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
      .def("dump", &mlir::ModuleOp::dump)
      // Returns the textual form of the module as a Python string.
      .def("getAsStr", [](mlir::ModuleOp& m) {
        std::string str;
        llvm::raw_string_ostream os(str);
        m.print(os);
        return os.str();
      });

  py::class_<mlir::FuncOp>(m, "FuncOp")
      // Creates the function and immediately gives it an entry block.
      .def("create",
           [](mlir::Location location, std::string name,
              mlir::FunctionType type) {
             auto func = mlir::FuncOp::create(location, name, type);
             func.addEntryBlock();
             return func;
           })
      .def(
          "getBody",
          [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
          py::return_value_policy::reference)
      .def("getArguments",
           [](mlir::FuncOp& f) { return f.getArguments().vec(); })
      .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
      .def("getType", &mlir::FuncOp::getType);

  // mlir::ReturnOp
  py::class_<mlir::ReturnOp>(m, "ReturnOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Value> values) -> mlir::Operation* {
             return opb
                 .create<mlir::ReturnOp>(loc,
                                         mlir::ArrayRef<mlir::Value>(values))
                 .getOperation();
           });

  // mlir::TF::AddV2Op
  py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
           });

  // mlir::TF::AnyOp
  //
  // pybind11 cannot see C++ default arguments of a lambda, so the
  // `keep_dims = false` default is declared through py::arg; otherwise Python
  // callers would be forced to always pass it. Because this is bound with
  // class_::def, pybind11 treats the first C++ parameter (the OpBuilder) as
  // the implicit `self`, so only the remaining four parameters are named.
  py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
      .def(
          "create",
          [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
             mlir::Value reduction_indices,
             bool keep_dims) -> mlir::Operation* {
            return opb
                .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
                                         reduction_indices, keep_dims)
                .getOperation();
          },
          py::arg("loc"), py::arg("input"), py::arg("reduction_indices"),
          py::arg("keep_dims") = false);

  // mlir::TF::ConstOp
  py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Attribute value) -> mlir::Operation* {
             return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
           });

  // mlir::TF::EqualOp
  // The trailing bool attribute is always passed as `true`.
  py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
                 .getOperation();
           });

  // mlir::TF::GreaterEqualOp
  py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
                 .getOperation();
           });

  // mlir::TF::GreaterOp
  py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LegacyCallOp
  py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Type> output, std::vector<mlir::Value> args,
              std::string f) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::LegacyCallOp>(
                     loc, mlir::ArrayRef<mlir::Type>(output),
                     mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
                 .getOperation();
           });

  // mlir::TF::LessEqualOp
  py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LessOp
  py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
           });

  // mlir::TF::NegOp
  py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Value x) -> mlir::Operation* {
             return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
           });

  // mlir::TF::NotEqualOp
  // Written to match the EqualOp binding: explicit Operation* return type and
  // the bool attribute (always `true`) built through the builder.
  py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::NotEqualOp>(loc, x, y,
                                               opb.getBoolAttr(true))
                 .getOperation();
           });

  // mlir::TF::SubOp
  py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
           });
}
|
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
// Registers the Python bindings for MLIR type classes on module `m`:
// the Type base class, the StandardTypes kind enum (nested under Type) and
// the Function/Float/Integer/UnrankedTensor/RankedTensor type subclasses.
void init_types(py::module& m) {
  // Base class.
  py::class_<mlir::Type> type_cls(m, "Type");
  type_cls.def("getKind", &mlir::Type::getKind);

  // Kind enum, scoped inside the Python `Type` class. Only BF16 is exposed.
  py::enum_<mlir::StandardTypes::Kind> kind_enum(type_cls,
                                                 "StandardTypes_Kind");
  kind_enum.value("BF16", mlir::StandardTypes::BF16);

  // Subclasses of Type.
  py::class_<mlir::FunctionType, mlir::Type> function_type_cls(m,
                                                               "FunctionType");
  function_type_cls.def("getResults", [](mlir::FunctionType& fn_type) {
    // Copy the result types into a std::vector so pybind11 can convert them.
    return fn_type.getResults().vec();
  });

  py::class_<mlir::FloatType, mlir::Type> float_type_cls(m, "FloatType");
  float_type_cls.def("get", &mlir::FloatType::get);

  // Only the (width, context) overload of IntegerType::get is exposed.
  auto make_integer_type = py::overload_cast<unsigned, mlir::MLIRContext*>(
      &mlir::IntegerType::get);
  py::class_<mlir::IntegerType, mlir::Type> integer_type_cls(m, "IntegerType");
  integer_type_cls.def("get", make_integer_type);

  py::class_<mlir::UnrankedTensorType, mlir::Type> unranked_tensor_cls(
      m, "UnrankedTensorType");
  unranked_tensor_cls.def("get", &mlir::UnrankedTensorType::get);

  py::class_<mlir::RankedTensorType, mlir::Type> ranked_tensor_cls(
      m, "RankedTensorType");
  ranked_tensor_cls.def(
      "get", [](std::vector<int64_t> dims, mlir::Type element_type) {
        return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(dims),
                                           element_type);
      });
}
|
@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
|
||||
]
|
||||
tool_names = [
|
||||
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
|
||||
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
|
||||
'xla-opt'
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
|
||||
'tensorflow/compiler/mlir',
|
||||
'tensorflow/compiler/mlir/lite',
|
||||
'tensorflow/compiler/mlir/tensorflow',
|
||||
'tensorflow/compiler/mlir/tfjs',
|
||||
'tensorflow/compiler/mlir/xla',
|
||||
'tensorflow/compiler/aot',
|
||||
'tensorflow/compiler/xla/service/mlir_gpu',
|
||||
|
@ -36,7 +36,7 @@ filegroup(
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1075,7 +1075,7 @@ genrule(
|
||||
srcs = [
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"ir/tf_generated_ops.td",
|
||||
"ir/tf_op_base.td",
|
||||
@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1278,6 +1279,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1292,6 +1294,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
@ -1765,7 +1765,7 @@ of corresponding 3-element vectors is cross-multiplied independently.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
|
||||
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
|
||||
let summary = "An Op to sum inputs across replicated TPU instances.";
|
||||
|
||||
let description = [{
|
||||
@ -1789,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> {
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
|
||||
|
||||
let description = [{
|
||||
@ -2907,6 +2907,8 @@ fill([2, 3], 9) ==> [[9, 9, 9]
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value dims, Value value"
|
||||
>];
|
||||
@ -3625,31 +3627,6 @@ tf.imag(input) ==> [4.75, 5.75]
|
||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Create a copy of `x` with the updated specified rows 'i' with values 'v'.
|
||||
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Creates a copy of tensor 'x' and updates the columns specified in tensor 'i'
|
||||
with the values 'v'. Originally this function was mutative however for
|
||||
compilation we make this operation create / operate on a copy.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
I32Tensor:$i,
|
||||
TF_Tensor:$v
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the reciprocal of x element-wise.";
|
||||
|
||||
@ -4350,7 +4327,7 @@ cublas.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> {
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> {
|
||||
let summary = [{
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
}];
|
||||
@ -6354,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
@ -10352,6 +10331,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> {
|
||||
let summary = [{
|
||||
A pseudo-op to represent host-side computation in an XLA program.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$inputs,
|
||||
|
||||
StrArrayAttr:$ancestors,
|
||||
TF_ShapeAttrArray:$shapes,
|
||||
SymbolRefAttr:$shape_inference_graph,
|
||||
StrAttr:$key,
|
||||
DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
|
||||
DefaultValuedAttr<I64Attr, "0">:$tpu_core
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$outputs
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> {
|
||||
let summary = "Wraps the XLA Sort operator, documented at";
|
||||
|
||||
@ -10400,6 +10406,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
|
||||
let summary = "An op to receive a tensor from the host.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_ShapeAttr:$shape,
|
||||
StrAttr:$key
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> {
|
||||
let summary = "Wraps the XLA Reduce operator, documented at";
|
||||
|
||||
@ -10464,6 +10488,23 @@ i=0...N-1.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
|
||||
let summary = "An op to send a tensor to the host.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
|
||||
StrAttr:$key
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the eigen decomposition of a batch of self-adjoint matrices
|
||||
@ -10547,6 +10588,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
|
||||
let summary = "A host-side computation called from a TPU device.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$inputs,
|
||||
|
||||
StrAttr:$key,
|
||||
DefaultValuedAttr<I64Attr, "0">:$tpu_core
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$outputs
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
|
||||
let summary = "An op that receives embeddng activations on the TPU.";
|
||||
|
||||
@ -10605,3 +10667,44 @@ used to look up the program in the compilation cache.
|
||||
TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>;
|
||||
TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> {
|
||||
let summary = [{
|
||||
A placeholder op to receive values from a running XLA computation.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_StrTensor:$dynamic_key,
|
||||
|
||||
StrAttr:$key,
|
||||
I64Attr:$device_ordinal
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$outputs
|
||||
);
|
||||
|
||||
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> {
|
||||
let summary = "A placeholder op to send values to a running XLA computation.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$inputs,
|
||||
TF_StrTensor:$dynamic_key,
|
||||
|
||||
StrAttr:$key,
|
||||
I64Attr:$device_ordinal
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#define TF_OP_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
|
||||
"$_op.getOperand(" # opId # ").getType(), "
|
||||
"$_op.getResult(" # resId # ").getType())">]>;
|
||||
|
||||
|
||||
// Predicate that passes the comma-joined `values` expressions, wrapped in an
// array, to TF::AreCastCompatible. Each entry of `values` is spliced verbatim
// into the generated C++ call, so it must already be a C++ expression.
class TF_AllTypesMatchPred<list<string> values> :
|
||||
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
|
||||
|
||||
// Op trait built on TF_AllTypesMatchPred: the named operands/results listed in
// `names` are checked together through TF::AreCastCompatible rather than for
// exact type equality.
//
// For every name `n`, the template string "$_self.getType()" has "$_self"
// substituted with "$n", so the predicate receives "$n.getType()" for each
// listed name.
class TF_AllTypesMatch<list<string> names> :
|
||||
PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
|
||||
TF_AllTypesMatchPred<
|
||||
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -110,48 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
|
||||
return !type || type.getRank() <= rank;
|
||||
}
|
||||
|
||||
// Returns true if the given pair of TensorFlow types can be cast to one
|
||||
// another. In other words, a single run-time value is legal for both the types.
|
||||
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
|
||||
static bool AreCastCompatible(Type a, Type b) {
|
||||
if (TensorCastOp::areCastCompatible(a, b)) return true;
|
||||
|
||||
// Resource types may optionally contain subtypes information that does not
|
||||
// match. Check subtypes compatibility when possible, otherwise treat them as
|
||||
// compatible.
|
||||
auto a_or_element_type = getElementTypeOrSelf(a);
|
||||
auto b_or_element_type = getElementTypeOrSelf(b);
|
||||
|
||||
auto a_kind = a_or_element_type.getKind();
|
||||
auto b_kind = b_or_element_type.getKind();
|
||||
|
||||
if (a_kind == TensorFlowTypes::RESOURCE &&
|
||||
b_kind == TensorFlowTypes::RESOURCE) {
|
||||
auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
|
||||
auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
|
||||
bool a_has_subtype = !a_resource_type.getSubtypes().empty();
|
||||
bool b_has_subtype = !b_resource_type.getSubtypes().empty();
|
||||
|
||||
if (!a_has_subtype || !b_has_subtype) return true;
|
||||
|
||||
assert(a_resource_type.getSubtypes().size() <= 1 &&
|
||||
"Resource type must have at most one subtype");
|
||||
assert(b_resource_type.getSubtypes().size() <= 1 &&
|
||||
"Resource type must have at most one subtype");
|
||||
|
||||
return TensorCastOp::areCastCompatible(
|
||||
a_resource_type.getSubtypes().front(),
|
||||
b_resource_type.getSubtypes().front());
|
||||
}
|
||||
|
||||
// Variant types may optionally contain subtypes information that need not
|
||||
// match. It is also not possible to compare subtypes for compatibility as
|
||||
// their interpretation depends on the ops operating on them. So, accept all
|
||||
// pairs of variant types.
|
||||
return a_kind == TensorFlowTypes::VARIANT &&
|
||||
b_kind == TensorFlowTypes::VARIANT;
|
||||
}
|
||||
|
||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||
return dim_or_rank == -1;
|
||||
}
|
||||
@ -503,9 +461,10 @@ LogicalResult FoldOperandsPermutation(
|
||||
namespace {
|
||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||
// known to be Identity (e.g X+0)
|
||||
template <typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
|
||||
template <
|
||||
typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
|
||||
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto result_op_type = arithmetic_op.getResult().getType();
|
||||
@ -520,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||
// value zero.
|
||||
int identity =
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
|
||||
std::is_same<OpT, RealDivOp>::value);
|
||||
|
||||
Type element_ty = lhs_type.getElementType();
|
||||
Attribute identity_attr;
|
||||
@ -537,6 +497,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
return arithmetic_op.x();
|
||||
}
|
||||
|
||||
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
|
||||
// TODO(chhe): we could fold and add an identity to force the broadcast.
|
||||
if (result_op_type != rhs_type) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool is_symmetric =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
@ -1413,7 +1379,7 @@ static LogicalResult Verify(DynamicStitchOp op) {
|
||||
auto expected_out_ty =
|
||||
RankedTensorType::get(expected_shape, out_ty.getElementType());
|
||||
|
||||
if (!AreCastCompatible(out_ty, expected_out_ty)) {
|
||||
if (!AreCastCompatible({out_ty, expected_out_ty})) {
|
||||
return op.emitOpError() << "has invalid output type; should be "
|
||||
"compatible with inferred type "
|
||||
<< expected_out_ty;
|
||||
@ -1647,7 +1613,7 @@ static ShapedType InferFillOpType(Value dims, Value value) {
|
||||
|
||||
llvm::SmallVector<int64_t, 4> shape;
|
||||
shape.reserve(dims_attr.getNumElements());
|
||||
for (const APInt &dim : dims_attr.getValues<APInt>()) {
|
||||
for (const APInt dim : dims_attr.getValues<APInt>()) {
|
||||
shape.push_back(dim.getSExtValue());
|
||||
}
|
||||
return RankedTensorType::get(shape, etype);
|
||||
@ -1658,6 +1624,29 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
|
||||
FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
|
||||
}
|
||||
|
||||
// Folds tf.Fill into a constant when the fill value is a constant.
//
// `operands[0]` is the (possibly null) constant for `dims` and `operands[1]`
// the (possibly null) constant for `value`. The result type of the op is used
// directly when it is fully static; otherwise the shape is rebuilt from the
// constant `dims`. Returns an empty result when the needed constants are
// missing.
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "fill op has two operand");

  auto fill_value = operands[1].dyn_cast_or_null<ElementsAttr>();
  if (!fill_value) return {};

  auto folded_type = getType().cast<ShapedType>();
  if (!folded_type.hasStaticShape()) {
    // The declared result type is not precise enough; derive the static shape
    // from the constant `dims` operand instead.
    auto dims_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
    if (!dims_attr) return {};

    llvm::SmallVector<int64_t, 4> static_shape;
    static_shape.reserve(dims_attr.getNumElements());
    for (const APInt extent : dims_attr.getValues<APInt>()) {
      static_shape.push_back(extent.getSExtValue());
    }
    folded_type =
        RankedTensorType::get(static_shape, folded_type.getElementType());
  }

  return DenseElementsAttr::get(folded_type, fill_value.getValue({}));
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FusedBatchNormGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1814,14 +1803,14 @@ static LogicalResult Verify(IfOp op) {
|
||||
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
||||
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(operandType, thenInputType))
|
||||
if (!AreCastCompatible({operandType, thenInputType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
thenInputType, operandType, i));
|
||||
|
||||
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(operandType, elseInputType))
|
||||
if (!AreCastCompatible({operandType, elseInputType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
@ -1829,7 +1818,7 @@ static LogicalResult Verify(IfOp op) {
|
||||
|
||||
// If branches have incompatible input types that means that no tensor can
|
||||
// serve as input to both the functions. Hence, the op is invalid.
|
||||
if (!AreCastCompatible(thenInputType, elseInputType))
|
||||
if (!AreCastCompatible({thenInputType, elseInputType}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"branches inputs have incompatible types {0} and {1} at index {2}",
|
||||
thenInputType, elseInputType, i));
|
||||
@ -1845,14 +1834,14 @@ static LogicalResult Verify(IfOp op) {
|
||||
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
||||
auto resultType = op.getResult(i).getType().cast<TensorType>();
|
||||
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(thenResultType, resultType))
|
||||
if (!AreCastCompatible({thenResultType, resultType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
thenResultType, resultType, i));
|
||||
|
||||
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(elseResultType, resultType))
|
||||
if (!AreCastCompatible({elseResultType, resultType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
@ -2426,6 +2415,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<RealDivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
// Folds tf.RealDiv by delegating to the shared identity-arithmetic folder
// (IdentityArithmeticOpFolder), instantiated for RealDivOp. `operands` holds
// the constant attributes (or null) for the op's operands.
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2616,9 +2609,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
||||
<< variadic_idx_str << " to match rank of operand"
|
||||
<< variadic_idx_str;
|
||||
} else if (result_ranked_type.hasStaticShape()) {
|
||||
// The operand is an unranked tensor, verify that the result is dynamic.
|
||||
return op->emitOpError("requires dynamic shape result")
|
||||
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
|
||||
// The operand is an unranked tensor, print a warning if the result
|
||||
// is static.
|
||||
// Note: We do not handle this situation as an error, this would be too
|
||||
// restrictive due to incompleteness of shape inference at this point.
|
||||
op->emitWarning("has static shape result")
|
||||
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
|
||||
}
|
||||
|
||||
Type element_type = result_ranked_type.getElementType();
|
||||
@ -3789,7 +3785,7 @@ static LogicalResult Verify(WhileOp op) {
|
||||
auto aType = a.second[idx];
|
||||
auto bType = b.second[idx];
|
||||
|
||||
if (!AreCastCompatible(aType, bType))
|
||||
if (!AreCastCompatible({aType, bType}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"{0} type {1} is incompatible with {2} type {3} at index {4}",
|
||||
a.first, aType, b.first, bType, idx));
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
|
||||
|
@ -905,5 +905,29 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
|
||||
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
||||
// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once
|
||||
// autogenerated op def matches.
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
|
||||
let summary = "Updates specified rows 'i' with values 'v'.";
|
||||
|
||||
let description = [{
|
||||
Computes `x[i, :] = v; return x`.
|
||||
|
||||
Originally this function is mutative however for compilation we make this
|
||||
operation create / operate on a copy of `x`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
I32Tensor:$i,
|
||||
TF_Tensor:$v
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -28,6 +28,134 @@ llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
|
||||
if (shaped_type.hasRank()) return shaped_type.getShape();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Merges cast compatible shapes and returns a more refined shape. The two
|
||||
// shapes are cast compatible if they have the same rank and at each dimension,
|
||||
// either both have same size or one of them is dynamic. Returns false if the
|
||||
// given shapes are not cast compatible. The refined shape is same or more
|
||||
// precise than the two input shapes.
|
||||
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
|
||||
llvm::ArrayRef<int64_t> b_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* refined_shape) {
|
||||
if (a_shape.size() != b_shape.size()) return false;
|
||||
int64_t rank = a_shape.size();
|
||||
refined_shape->reserve(rank);
|
||||
for (auto dims : llvm::zip(a_shape, b_shape)) {
|
||||
int64_t dim1 = std::get<0>(dims);
|
||||
int64_t dim2 = std::get<1>(dims);
|
||||
|
||||
if (mlir::ShapedType::isDynamic(dim1)) {
|
||||
refined_shape->push_back(dim2);
|
||||
continue;
|
||||
}
|
||||
if (mlir::ShapedType::isDynamic(dim2)) {
|
||||
refined_shape->push_back(dim1);
|
||||
continue;
|
||||
}
|
||||
if (dim1 == dim2) {
|
||||
refined_shape->push_back(dim1);
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given two types `a` and `b`, returns a refined type which is cast compatible
|
||||
// with both `a` and `b` and is equal to or more precise than both of them. It
|
||||
// returns empty Type if the input types are not cast compatible.
|
||||
//
|
||||
// The two types are considered cast compatible if they have dynamically equal
|
||||
// shapes and element type. For element types that do not have subtypes, they
|
||||
// must be equal. However for TensorFlow types such as Resource and Variant,
|
||||
// that also have subtypes, we recursively check for subtype compatibilty for
|
||||
// Resource types and assume all variant types are cast compatible. If either
|
||||
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
|
||||
//
|
||||
// The returned type is same or more precise than the input types. For example,
|
||||
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
|
||||
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
|
||||
//
|
||||
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
|
||||
// might allow operands to either be same as result type or be a ref type
|
||||
// corresponding to it.
|
||||
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
|
||||
bool may_ignore_ref_type_a) {
|
||||
// Fast path if everything is equal.
|
||||
if (a == b) return b;
|
||||
|
||||
auto a_tt = a.dyn_cast<mlir::TensorType>();
|
||||
auto b_tt = b.dyn_cast<mlir::TensorType>();
|
||||
|
||||
// If only one of a or b is a tensor type, they are incompatible.
|
||||
if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
|
||||
|
||||
// For non-tensor types, we do not need to worry about shape and can return
|
||||
// early.
|
||||
if (!a_tt && !b_tt) {
|
||||
// Remove ref types.
|
||||
if (may_ignore_ref_type_a) {
|
||||
if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
|
||||
a = ref_type.RemoveRef();
|
||||
if (a == b) return a;
|
||||
}
|
||||
}
|
||||
if (a.getKind() != b.getKind()) return nullptr;
|
||||
|
||||
// If either is not a type that contain subtypes then the types are not cast
|
||||
// compatible.
|
||||
auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
|
||||
auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
|
||||
if (!a_wst || !b_wst) return nullptr;
|
||||
|
||||
// For Variant types we are more permissive right now and accept all pairs
|
||||
// of Variant types. If we are more constrainted and check compatibility of
|
||||
// subtypes, we might reject valid graphs.
|
||||
// TODO(prakalps): Variant doesn't have a subtype, we assign it
|
||||
// one, so we should only assign it one when we know the subtype. Then we
|
||||
// can be more constrained and check subtypes for cast compatibility as
|
||||
// well.
|
||||
if (a.isa<mlir::TF::VariantType>()) return a;
|
||||
|
||||
// For Resource types, we recursively check the subtypes for cast
|
||||
// compatibility, if possible. Otherwise treat them as compatible.
|
||||
auto a_wst_st = a_wst.GetSubtypes();
|
||||
auto b_wst_st = b_wst.GetSubtypes();
|
||||
if (a_wst_st.empty() || b_wst_st.empty()) return a;
|
||||
if (a_wst_st.size() != b_wst_st.size()) return nullptr;
|
||||
llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
|
||||
for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
|
||||
mlir::Type refined_st =
|
||||
GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
|
||||
/*may_ignore_ref_type_a=*/false);
|
||||
if (!refined_st) return nullptr;
|
||||
refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
|
||||
}
|
||||
|
||||
return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
|
||||
}
|
||||
|
||||
// For tensor types, check compatibility of both element type and shape.
|
||||
mlir::Type refined_element_ty = GetCastCompatibleType(
|
||||
a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
|
||||
if (!refined_element_ty) return nullptr;
|
||||
|
||||
if (!a_tt.hasRank() && !b_tt.hasRank()) {
|
||||
return mlir::UnrankedTensorType::get(refined_element_ty);
|
||||
}
|
||||
if (!a_tt.hasRank()) {
|
||||
return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
|
||||
}
|
||||
if (!b_tt.hasRank()) {
|
||||
return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 8> refined_shape;
|
||||
if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
|
||||
return nullptr;
|
||||
|
||||
return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
|
||||
|
||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
bool may_ignore_ref_type_lhs) {
|
||||
// Fast path if everything is equal.
|
||||
if (lhs == rhs) return true;
|
||||
return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
|
||||
}
|
||||
|
||||
// In TF all values are tensors.
|
||||
auto lhs_tt = lhs.cast<TensorType>();
|
||||
auto rhs_tt = rhs.cast<TensorType>();
|
||||
|
||||
// Verify matching element types. These should be identical dynamically,
|
||||
// so this allows for types not yet fully refined.
|
||||
auto lhs_et = lhs_tt.getElementType();
|
||||
auto rhs_et = rhs_tt.getElementType();
|
||||
if (lhs_et == rhs_et) return true;
|
||||
|
||||
// Remove ref types.
|
||||
if (may_ignore_ref_type_lhs) {
|
||||
if (auto ref_type = lhs_et.dyn_cast<TF::TensorFlowRefType>()) {
|
||||
lhs_et = ref_type.RemoveRef();
|
||||
if (lhs_et == rhs_et) return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (lhs_et.getKind() != rhs_et.getKind()) return false;
|
||||
|
||||
// If either is not type that contain subtypes then the element types don't
|
||||
// match.
|
||||
auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
if (!lhs_wst || !rhs_wst) return false;
|
||||
|
||||
// Consider the subtype recursively.
|
||||
auto lhs_wst_st = lhs_wst.GetSubtypes();
|
||||
auto rhs_wst_st = rhs_wst.GetSubtypes();
|
||||
if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true;
|
||||
if (lhs_wst_st.size() != rhs_wst_st.size()) return false;
|
||||
for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
|
||||
if (!HasCompatibleElementTypes(std::get<0>(subtypes),
|
||||
std::get<1>(subtypes)))
|
||||
return false;
|
||||
bool AreCastCompatible(ArrayRef<Type> types) {
|
||||
Type common = types.front();
|
||||
for (auto type : types.drop_front()) {
|
||||
Type refined_type =
|
||||
GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
|
||||
if (!refined_type) return false;
|
||||
common = refined_type;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
|
||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
bool may_ignore_ref_type_lhs = false);
|
||||
|
||||
// Returns true if all TensorFlow types can be cast to one
|
||||
// another. In other words, a single run-time value is legal for all the types.
|
||||
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
|
||||
// compatible.
|
||||
bool AreCastCompatible(ArrayRef<Type> types);
|
||||
|
||||
} // end namespace TF
|
||||
} // end namespace mlir
|
||||
|
||||
|
@ -471,3 +471,14 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
|
||||
// CHECK: return [[VAL0]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @foldFill
|
||||
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
|
||||
%2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
|
||||
%3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32>
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user