Merge branch 'master' into 12829-extract_glimpse

This commit is contained in:
Mihai Maruseac 2020-05-19 15:35:19 +00:00 committed by GitHub
commit 158d128323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
956 changed files with 40258 additions and 15186 deletions

View File

@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# This config refers to building with CUDA available. It does not necessarily # This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels. # mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true build:using_cuda --define=using_cuda=true
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++ build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17 build:c++1z --config=c++17
# Enable using platform specific build settings # Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w build:linux --copt=-w
build:macos --copt=-w build:macos --copt=-w
build:windows --copt=/w build:windows --copt=/w
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows. # TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode. # By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14 build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14 build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14 build:macos --cxxopt=-std=c++14

29
.github/bot_config.yml vendored Normal file
View File

@ -0,0 +1,29 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler-related issues
compiler_assignees:
- joker-eph

View File

@ -103,17 +103,17 @@ open-source software development:
### Official Builds ### Official Builds
Build Type | Status | Artifacts Build Type | Status | Artifacts
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- ------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) **macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
### Community Supported Builds ### Community Supported Builds

View File

@ -3,6 +3,31 @@
## Breaking Changes ## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case where `centered=False` and `normalized=False`. This is a breaking change as the output is different from (incorrect) previous versions. Note this breaking change only impacts `tf.image.extract_glimpse` and `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of `tf.compat.v1.image.extract_glimpse` does not change. The behavior of existing C++ kernel `ExtractGlimpse` does not change as well, so saved models will not be impacted. * `tf.image.extract_glimpse` has been updated to correctly process the case where `centered=False` and `normalized=False`. This is a breaking change as the output is different from (incorrect) previous versions. Note this breaking change only impacts `tf.image.extract_glimpse` and `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of `tf.compat.v1.image.extract_glimpse` does not change. The behavior of existing C++ kernel `ExtractGlimpse` does not change as well, so saved models will not be impacted.
# Release 2.1.1
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
# Release 2.0.2
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 1.15.3
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 2.2.0 # Release 2.2.0
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).

View File

@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
It is possible to write models that are secure in a sense that they can safely It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries. bugs either in TensorFlow or in dependent libraries.
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system). hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability. such behaviors constitute a vulnerability.

View File

@ -144,7 +144,7 @@ def write_to_bazelrc(line):
def write_action_env_to_bazelrc(var_name, var): def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
def run_shell(cmd, allow_non_zero=False, stderr=None): def run_shell(cmd, allow_non_zero=False, stderr=None):
@ -205,7 +205,7 @@ def setup_python(environ_cp):
# Get PYTHON_BIN_PATH, default is the current running python. # Get PYTHON_BIN_PATH, default is the current running python.
default_python_bin_path = sys.executable default_python_bin_path = sys.executable
ask_python_bin_path = ('Please specify the location of python. [Default is ' ask_python_bin_path = ('Please specify the location of python. [Default is '
'%s]: ') % default_python_bin_path '{}]: ').format(default_python_bin_path)
while True: while True:
python_bin_path = get_from_env_or_user_or_default(environ_cp, python_bin_path = get_from_env_or_user_or_default(environ_cp,
'PYTHON_BIN_PATH', 'PYTHON_BIN_PATH',
@ -215,9 +215,10 @@ def setup_python(environ_cp):
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break break
elif not os.path.exists(python_bin_path): elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path) print('Invalid python path: {} cannot be found.'.format(python_bin_path))
else: else:
print('%s is not executable. Is it the python binary?' % python_bin_path) print('{} is not executable. Is it the python binary?'.format(
python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = '' environ_cp['PYTHON_BIN_PATH'] = ''
# Convert python path to Windows style before checking lib and version # Convert python path to Windows style before checking lib and version
@ -236,7 +237,7 @@ def setup_python(environ_cp):
default_python_lib_path = python_lib_paths[0] default_python_lib_path = python_lib_paths[0]
python_lib_path = get_input( python_lib_path = get_input(
'Please input the desired Python library path to use. ' 'Please input the desired Python library path to use. '
'Default is [%s]\n' % python_lib_paths[0]) 'Default is [{}]\n'.format(python_lib_paths[0]))
if not python_lib_path: if not python_lib_path:
python_lib_path = default_python_lib_path python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path
@ -252,7 +253,7 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl # Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = python_bin_path environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# If chosen python_lib_path is from a path specified in the PYTHONPATH # If chosen python_lib_path is from a path specified in the PYTHONPATH
@ -266,7 +267,7 @@ def setup_python(environ_cp):
with open( with open(
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
'w') as f: 'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
def reset_tf_configure_bazelrc(): def reset_tf_configure_bazelrc():
@ -320,11 +321,12 @@ def get_var(environ_cp,
Raise the error to avoid infinitely looping. Raise the error to avoid infinitely looping.
""" """
if not question: if not question:
question = 'Do you wish to build TensorFlow with %s support?' % query_item question = 'Do you wish to build TensorFlow with {} support?'.format(
query_item)
if not yes_reply: if not yes_reply:
yes_reply = '%s support will be enabled for TensorFlow.' % query_item yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
if not no_reply: if not no_reply:
no_reply = 'No %s' % yes_reply no_reply = 'No {}'.format(yes_reply)
yes_reply += '\n' yes_reply += '\n'
no_reply += '\n' no_reply += '\n'
@ -368,7 +370,7 @@ def get_var(environ_cp,
print(no_reply) print(no_reply)
var = False var = False
else: else:
print('Invalid selection: %s' % user_input_origin) print('Invalid selection: {}'.format(user_input_origin))
return var return var
@ -1385,7 +1387,6 @@ def main():
# Windows. # Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos(): if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0'

View File

@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
# TODO(b/154762408) Remove this package group once it's no longer needed. # TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist") package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
filegroup( filegroup(
name = "intel_binary_blob", name = "intel_binary_blob",
data = if_mkl_ml( data = if_mkl_ml(

View File

@ -85,7 +85,7 @@ tf_cuda_library(
], ],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//tensorflow:chromiumos": [ "//tensorflow:chromiumos": [
":tf_attrtype", ":tf_attrtype",
@ -182,7 +182,7 @@ tf_cuda_library(
":tf_status_internal", ":tf_status_internal",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":tf_status", ":tf_status",
@ -219,7 +219,7 @@ tf_cuda_library(
], ],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -232,12 +232,13 @@ cc_library(
srcs = ["tf_status.cc"], srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"], hdrs = ["tf_status.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = [
":tf_status_internal",
] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [ "//conditions:default": [
":tf_status_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
}), }),
@ -259,10 +260,15 @@ cc_library(
name = "tensor_interface", name = "tensor_interface",
hdrs = ["tensor_interface.h"], hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = select({
"//tensorflow/core:lib", "//tensorflow:android": [
"//tensorflow/core:protos_all_cc", "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
}),
) )
cc_library( cc_library(
@ -272,7 +278,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -286,16 +292,17 @@ cc_library(
srcs = ["tf_tensor.cc"], srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"], hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [ "//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -311,14 +318,15 @@ tf_cuda_library(
"tf_tensor_internal.h", "tf_tensor_internal.h",
], ],
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = select({ deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
], ],
"//conditions:default": [ "//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts", "//tensorflow/core/platform:casts",
@ -386,8 +394,14 @@ tf_cuda_library(
deps = [ deps = [
":tf_status", ":tf_status",
":tf_status_internal", ":tf_status_internal",
"//tensorflow/core:lib", ] + select({
], "//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
) )
tf_cc_test( tf_cc_test(
@ -426,7 +440,7 @@ tf_cuda_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -457,7 +471,7 @@ tf_cuda_library(
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
":c_api_internal", ":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":c_api_internal", ":c_api_internal",
@ -484,7 +498,7 @@ tf_cuda_library(
":tf_status_helper", ":tf_status_helper",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -35,7 +35,7 @@ tf_cuda_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":context_interface", ":context_interface",
@ -319,6 +319,7 @@ tf_cuda_cc_test(
tags = [ tags = [
"noguitar", # TODO(b/155445984): flaky "noguitar", # TODO(b/155445984): flaky
#"guitar", #"guitar",
"notap", # TODO(b/156981931): flaky
"multi_gpu", "multi_gpu",
], ],
deps = [ deps = [
@ -357,10 +358,13 @@ tf_cuda_cc_test(
":c_api_test_util", ":c_api_test_util",
":tfe_tensorhandle_internal", ":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util", "//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -412,7 +416,7 @@ tf_cuda_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":c_api", ":c_api",
@ -448,6 +452,8 @@ tf_cuda_library(
"//conditions:default": [], "//conditions:default": [],
}) + [ }) + [
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",

View File

@ -899,9 +899,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM) #if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = status->status = tensorflow::unwrap(ctx)->AsyncWait();
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
} }
@ -924,7 +922,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
context->GetDevicePlacementPolicy()); context->GetDevicePlacementPolicy());
} }
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor; tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor); status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr; if (!status->status.ok()) return nullptr;

View File

@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces. // placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle; typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status); TF_Status* status);
// Indicates that the caller will not be using `h` any more. // Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);

View File

@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks); return GetServerDef("localhost", num_tasks);
} }
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->at(task_index) =
tensorflow::strings::StrCat("localhost:", port);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) { const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
TF_DeleteStatus(status); TF_DeleteStatus(status);
} }
// Read the value of variable `var` and save it into `out_value`.
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
TFE_TensorHandle** out_value) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_Execute(op, out_value, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) { void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2); tensorflow::ServerDef server_def = GetServerDef(2);
@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
TestRemoteExecuteUpdateServerDef(true); TestRemoteExecuteUpdateServerDef(true);
} }
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
TFE_TensorHandle* value_handle = nullptr;
ReadVariable(ctx, var_handle1, &value_handle);
CheckTFE_TensorHandleHasFloats(value_handle, {2});
TFE_DeleteTensorHandle(value_handle);
// Start a new worker to replace task:1
ReplaceTaskInServerDef(&server_def, 1);
server_def.set_task_index(1);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
// Update server def to replace the remote device with the device info on the
// new worker (different incarnation ID).
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// The device of var_handle0 is local device which is the same before and
// after cluster update. Remove resource with valid device should succeed.
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle0, status);
TFE_OpSetDevice(op, dev0_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
// The device of var_handle1 is remote device, which was replaced during
// cluster update. Removing resource with invalid device should fail
// gracefully (i.e., with error status) instead of crashing with segfaults.
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle1, status);
TFE_OpSetDevice(op, dev1_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
TestRemoteExecuteUpdateServerDefResourceAccess(false);
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
TestRemoteExecuteUpdateServerDefResourceAccess(true);
}
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) { void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
// Fail fast on GetStatus requests so we can get errors instead of timeout // Fail fast on GetStatus requests so we can get errors instead of timeout
// when updating cluster with non-exsitent worker // when updating cluster with non-exsitent worker
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
int port = tensorflow::testing::PickUnusedPortOrDie(); int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert( job_def->mutable_tasks()->insert(
{2, tensorflow::strings::StrCat("localhost:", port)}); {2, tensorflow::strings::StrCat("localhost:", port)});
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString(); string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status); serialized_update.size(), status);

View File

@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
std::move(tensor_handles), context, &handle); std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle); return tensorflow::wrap(handle);
} }
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
}

View File

@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status); TF_Status* status);
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
#ifdef __cplusplus #ifdef __cplusplus
} /* end extern "C" */ } /* end extern "C" */
#endif #endif

View File

@ -19,11 +19,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace { namespace {
@ -434,7 +439,26 @@ string AddVariablesFunction() {
return def.SerializeAsString(); return def.SerializeAsString();
} }
TEST(CAPI, TestFunctionWithPackedInput) { void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3); tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0. // This server def has the task index set to 0.
@ -474,6 +498,12 @@ TEST(CAPI, TestFunctionWithPackedInput) {
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle. // Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3; int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2}; std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
@ -502,6 +532,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status); TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr}; TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1; int num_retvals = 1;
@ -537,6 +571,189 @@ TEST(CAPI, TestFunctionWithPackedInput) {
worker_server2.release(); worker_server2.release();
} }
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2); tensorflow::ServerDef server_def = GetServerDef(2);

View File

@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr; h = nullptr;
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
} }
tensorflow::testing::StopTiming(); tensorflow::testing::StopTiming();
TFE_DeleteOp(op); TFE_DeleteOp(op);

View File

@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
TFE_TensorHandle* var_handle = nullptr; TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1; int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status); TFE_Execute(op, &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_DeleteOp(op); TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr; if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals); CHECK_EQ(1, num_retvals);

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
@ -26,6 +28,51 @@ using tensorflow::string;
using tensorflow::internal::OutputList; using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap; using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
return *factories;
}
static const char* default_factory = "<unset>";
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
(GetFactories()[name] == factory) &&
"Duplicate tracing factory registration");
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
factories_sorted.insert(factory.first);
const char* comma = "";
for (const string& factory : factories_sorted) {
msg += comma + factory;
comma = ", ";
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
} // end namespace internal
} // end namespace tensorflow
// ============================================================================= // =============================================================================
// Public C API entry points // Public C API entry points
// //
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
// //
// ============================================================================= // =============================================================================
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function, it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
TF_DeleteExecutionContext(ctx);
return func;
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return wrap(unwrap(o)->outputs[i]); return wrap(unwrap(o)->outputs[i]);
} }
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status* s) {
unwrap(o)->outputs.push_back(unwrap(tensor));
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) { TF_Status* s) {

View File

@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
// setting functional attributes of other composite ops e.g. control flow. // setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction; typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function. // This allows the client to swap the implementation of the tracing engine.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s); // Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
// tracing, a function can be obtained by TF_FinalizeFunction.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
// Creates a context for eager execution of operations. // Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s); TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*); void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
// Create an operation suitable to use with the provided context. The operation // Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently. // requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_DeleteAbstractTensor(TF_AbstractTensor*); void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing // TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation. // an operation, or provided to create a function.
// It just lets us not specify the number of outputs of an operation // When executing an operation in an eager context, the expected number of
// beforehand. This forces a memory allocation in the runtime, which is bad, but // outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
typedef struct TF_OutputList TF_OutputList; typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList(); TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o); void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); // Prepare tracing to the expected number of output for an operation.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
// Return the number of outputs in the list.
int TF_OutputListNumOutputs(TF_OutputList* o); int TF_OutputListNumOutputs(TF_OutputList* o);
// Return the `i`th output in the list.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Append a tensor at the end of the output list, growing its size by one.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph. The output tensors are // capture some inputs and then add a node in the graph. The output tensors are
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_ExecutionContext* ctx, TF_Status* s); TF_ExecutionContext* ctx, TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the // Creates a new TF_AbstractFunction from the current tracing states in the
// context. The returned TF_GraphToFunction must be deleted by the client. // context. The provided `ctx` is consumed by this API call and deleted.
// The returned TF_AbstractFunction must be deleted by the client,
// TODO(aminim): clarify the contract on the state of the context after this // TODO(aminim): clarify the contract on the state of the context after this
// call. // call.
TF_AbstractFunction* TF_ExecutionContextToFunction( TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, TF_OutputList*, TF_Status*);
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
void TF_DeleteAbstractFunction(TF_AbstractFunction*); void TF_DeleteAbstractFunction(TF_AbstractFunction*);

View File

@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
} }
} }
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s); auto* func = afunc->GetTfFunction(s);
if (!func) { if (!func) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
static constexpr AbstractFunctionKind kKind = kGraphFunc; static constexpr AbstractFunctionKind kKind = kGraphFunc;
}; };
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e. // GraphContext wraps a TF_Graph modeling a single function and manages the
// adding them to the graph. // "execution" of operation, i.e. adding them to the function.
class GraphContext : public ExecutionContext { class GraphContext : public ExecutionContext {
public: public:
GraphContext() explicit GraphContext(const char* name)
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {} : ExecutionContext(kKind),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override { AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context. // TODO(srbs): Should the lifetime of this op be tied to the context.
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
return; return;
} }
auto* tf_opdesc = graph_op->op_.release(); auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]); auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) { if (!graph_tensor) {
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
} }
} }
TF_Function* ToFunction(const char* fn_name, int num_inputs, AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
const GraphTensor* inputs, int num_outputs, TF_OperationDescription* opdesc =
const GraphTensor* outputs, TF_Status* status) const { TF_NewOperation(graph_.get(), "Placeholder",
std::vector<TF_Output> graph_inputs; absl::StrCat("_input_", inputs_.size()).c_str());
graph_inputs.resize(num_inputs); TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs; std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs); graph_outputs.reserve(outputs->outputs.size());
for (int i = 0; i < num_inputs; i++) { for (AbstractTensor* abstract_output : outputs->outputs) {
graph_inputs[i] = inputs[i].output; GraphTensor* output = dyncast<GraphTensor>(abstract_output);
} if (!output) {
for (int i = 0; i < num_outputs; i++) { TF_SetStatus(s, TF_UNIMPLEMENTED,
graph_outputs[i] = outputs[i].output; "Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
}
graph_outputs.push_back(output->output);
} }
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr, func->func = TF_GraphToFunction(
graph_inputs.size(), graph_inputs.data(), graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
nullptr, nullptr, fn_name, status); if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
} }
void RegisterFunction(AbstractFunction* func, TF_Status* s) override { void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
private: private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_; std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
}; };
// Helper that converts the graph currently held in the context into a function. static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
static AbstractFunction* ExecutionContextToFunction( return new GraphContext(name);
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
} }
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
return true;
}();
} // namespace internal } // namespace internal
} // namespace tensorflow } // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
namespace internal { namespace internal {
@ -148,6 +149,17 @@ struct ExecutionContext {
// Creates an empty AbstractOperation suitable to use with this context. // Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0; virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext,
// it'll always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
// Finalize this context and make a function out of it. The context is in a
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext,
// it'll always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
// Registers a functions with this context, after this the function is // Registers a functions with this context, after this the function is
// available to be called/referenced by its name in this context. // available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
@ -156,6 +168,11 @@ struct ExecutionContext {
const ExecutionContextKind k; const ExecutionContextKind k;
}; };
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
// Create utilities to wrap/unwrap: this convert from the C opaque types to the // Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back. // C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ #define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \

View File

@ -29,7 +29,12 @@ using tensorflow::string;
namespace tensorflow { namespace tensorflow {
namespace { namespace {
TEST(UnifiedCAPI, TestBasicEager) { class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx); TF_DeleteExecutionContext(ctx);
} }
TEST(UnifiedCAPI, TestBasicGraph) { TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); // Start a new function / execution context.
string fn_name = "double";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph. auto* placeholder_t =
auto* placeholder_op = TF_NewAbstractOp(graph_ctx); TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation. // Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx); auto* add_op = TF_NewAbstractOp(graph_ctx);
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
// Execute. // Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs. // Clean up operation and inputs.
TF_DeleteAbstractOp(add_op); TF_DeleteAbstractOp(add_op);
string fn_name = "double"; TF_AbstractFunction* func =
TF_AbstractFunction* func = TF_ExecutionContextToFunction( TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
// Build eager context. // Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -174,17 +160,160 @@ TEST(UnifiedCAPI, TestBasicGraph) {
ASSERT_EQ(*f_value, 4.0); ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(add_outputs); TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op); TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t); TF_DeleteAbstractTensor(input_t);
TF_DeleteAbstractTensor(final_result); TF_DeleteAbstractTensor(final_result);
TF_DeleteTensor(f_t); TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func); TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx); TF_DeleteExecutionContext(eager_execution_ctx);
} }
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the output of both add operations, create a new list
// and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* We traced so far this function:
*
* def two_adds(a, b):
* my_add1 = a + b
* my_add2 = b + b
* return my_add1, my_add2
*
* Now we will execute this function with an eager context:
*
* output1, output2 = two_adds(2.0, 3.0)
*
* and check that we got 5.0 and 6.0 as results.
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
float results[2];
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
TF_DeleteTensor(f_t);
}
ASSERT_EQ(results[0], 5.0);
ASSERT_EQ(results[1], 6.0);
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -193,18 +322,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts); TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction( TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
ASSERT_EQ(nullptr, func); ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(ctx);
} }
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph. // Add a placeholder to the graph.
@ -222,10 +348,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(graph_ctx);
} }
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph. // Add a placeholder to the graph.
@ -243,7 +369,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(graph_ctx);
} }
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context. // Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
@ -273,7 +399,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context. // Build a Graph context.
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context. // Execute eager op using graph context.
@ -289,10 +416,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(graph_ctx);
} }
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph. // Add a placeholder to the graph.
@ -349,5 +477,7 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx); TF_DeleteExecutionContext(eager_execution_ctx);
} }
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -101,6 +101,9 @@ class AbstractContextInterface {
// Destroy the step resource container for a training step. // Destroy the step resource container for a training step.
virtual void EndStep() = 0; virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
protected: protected:
virtual ~AbstractContextInterface() {} virtual ~AbstractContextInterface() {}
}; };

View File

@ -44,6 +44,7 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"], srcs = ["parallel_device_test.cc"],
deps = [ deps = [
":parallel_device", ":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental", "//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
@ -53,3 +54,19 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
], ],
) )
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -92,6 +92,10 @@ class ParallelDevice {
TFE_TensorHandle* tensor, TFE_TensorHandle* tensor,
TF_Status* status) const; TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the // Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with // ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or // its corresponding inputs from the input ParallelTensors (or
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status); status);
} }
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute( absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs, TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes, const char* operation_name, const TFE_OpAttrs* attributes,
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} }
result.emplace(std::move(outputs)); result.emplace(std::move(outputs));
return result; return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
} }
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results( maybe_parallel_results(

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
// TODO(allenl): Figure out if we need this op, and if so whether we should move
// it to core TF. Right now the eager C API does some checking of op
// registrations before calling into custom devices, but we may be able to avoid
// that.
REGISTER_OP("DeviceID")
.Output("device_id: int64")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);

View File

@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
} }
// Assert that `handle` is equal to `expected_value`. // Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) { template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero( std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value, EXPECT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get()))); *static_cast<value_type*>(TF_TensorData(value_zero.get())));
} }
template <std::size_t num_devices> template <std::size_t num_devices>
@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get()); ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.); ExpectScalarEq<float>(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.); ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device = std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get()); ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.); ExpectScalarEq<float>(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.); ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device = std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device); ASSERT_EQ(underlying_devices[1], second_device);
} }
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
} }
TEST(PARALLEL_DEVICE, TestBasicCPU) { TEST(PARALLEL_DEVICE, TestBasicCPU) {
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device. // The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.); ExpectScalarEq<float>(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.); ExpectScalarEq<float>(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices. // Verify that the mirrors are placed on the component devices.
std::string first_device = std::string first_device =
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
&second_components, status.get()); &second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.); ExpectScalarEq<float>(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices. // Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName( std::string first_device = TFE_TensorHandleBackingDeviceName(
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::array<TensorHandlePtr, 2> first_components; std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(), ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get()); &first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.); ExpectScalarEq<float>(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.); ExpectScalarEq<float>(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(), first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get()); status.get());
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get()); status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.); ExpectScalarEq<float>(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.); ExpectScalarEq<float>(result_components[1].get(), 3.);
} }
void RegisterCollectiveMulFunction(TFE_Context* context, void RegisterCollectiveMulFunction(TFE_Context* context,
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get()); status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.); ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.); ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName( std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get()); result_components[0].get(), status.get());

View File

@ -31,9 +31,6 @@ cc_library(
"//tensorflow/c/experimental/saved_model/public:concrete_function.h", "//tensorflow/c/experimental/saved_model/public:concrete_function.h",
], ],
copts = tf_copts(), copts = tf_copts(),
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [ visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__", "//tensorflow/c/experimental/saved_model/public:__pkg__",
], ],
@ -41,6 +38,8 @@ cc_library(
":concrete_function_type", ":concrete_function_type",
":function_metadata", ":function_metadata",
":function_metadata_type", ":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_internal",
@ -160,6 +159,38 @@ cc_library(
], ],
) )
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface",
],
)
tf_cc_test( tf_cc_test(
name = "saved_model_api_test", name = "saved_model_api_test",
size = "small", size = "small",

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" { extern "C" {
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata())); &tensorflow::unwrap(func)->GetFunctionMetadata()));
} }
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate TF_ConcreteFunction* func) {
// internal header, and implement this function. return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
return nullptr;
} }
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
return tensorflow::unwrap(list)->size();
}
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>,
TF_TensorHandleList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_

View File

@ -24,6 +24,7 @@ exports_files(
"concrete_function_list.h", "concrete_function_list.h",
"function_metadata.h", "function_metadata.h",
"saved_model_api.h", "saved_model_api.h",
"tensorhandle_list.h",
], ],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
) )
@ -39,6 +40,7 @@ cc_library(
":concrete_function_list", ":concrete_function_list",
":function_metadata", ":function_metadata",
":saved_model_api", ":saved_model_api",
":tensorhandle_list",
], ],
) )
@ -61,3 +63,8 @@ alias(
name = "saved_model_api", name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
) )
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func); TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function. // Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures( TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func); TF_ConcreteFunction* func);
// Returns a TFE_Op suitable for executing this function. // Returns a TFE_Op suitable for executing this function.

View File

@ -21,19 +21,27 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that is acts like a list of TF_ConcreteFunction pointers. // An opaque type that is acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList; typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
// Returns the size of `list`. // Returns the size of `list`.
TF_CAPI_EXPORT size_t TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list); TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list. // Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet( TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_ConcreteFunctionList* list, int i); TF_ConcreteFunctionList* list, int i);
// Deletes `list`. // Deletes `list`.
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList( TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
TF_ConcreteFunctionList* list); TF_ConcreteFunctionList* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_

View File

@ -62,3 +62,17 @@ cc_library(
"//tensorflow/c:tf_tensor", "//tensorflow/c:tf_tensor",
], ],
) )
cc_library(
name = "tensorhandle",
hdrs = [
"tensorhandle.h",
],
deps = [
":runtime",
":status",
":tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own // Runtime represents an opaque instance of a Tensorflow runtime, with its own
@ -40,6 +41,7 @@ class Runtime {
private: private:
friend class RuntimeBuilder; friend class RuntimeBuilder;
friend class SavedModelAPI; friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TFE_Context. Takes ownership of ctx. // Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@ -63,6 +65,7 @@ class Runtime {
}; };
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h" #include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// Status is a wrapper around an error code and an optional error message. // Status is a wrapper around an error code and an optional error message.
@ -57,6 +58,7 @@ class Status {
friend class RuntimeBuilder; friend class RuntimeBuilder;
friend class Runtime; friend class Runtime;
friend class SavedModelAPI; friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TF_Status*, and takes ownership of it. // Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {} explicit Status(TF_Status* status) : status_(status) {}
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h" #include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// Tensor represents an n-dimensional array of values. // Tensor represents an n-dimensional array of values.
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// An opaque representation of a tensor computed/managed by the Tensorflow
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
// to tensors placed in memory of different devices or remote address spaces.
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
// from it.
class TensorHandle {
public:
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
// status->ok() will be false, and the returned Tensor must not be used.
Tensor Resolve(Status* status);
// Constructs a TensorHandle from a Tensor. If an error occurred,
// status->ok() will be false, and the returned TensorHandle must not be used.
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
Status* status);
// TensorHandle is movable, and not copyable
TensorHandle(TensorHandle&&) = default;
TensorHandle& operator=(TensorHandle&&) = default;
private:
// Wraps a TFE_TensorHandle. Takes ownership of handle.
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
// TensorHandle is not copyable
TensorHandle(const TensorHandle&) = delete;
TensorHandle& operator=(const TensorHandle&) = delete;
// Returns the underlying TFE_TensorHandle that this object wraps.
// This object retains ownership of the pointer.
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
// and takes ownership of handle.
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
struct TFETensorHandleDeleter {
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
};
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
};
inline Tensor TensorHandle::Resolve(Status* status) {
TF_Tensor* tensor =
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
if (!status->ok()) {
return Tensor(nullptr);
}
return Tensor(tensor);
}
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
const Runtime& runtime,
Status* status) {
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
if (!status->ok()) {
return TensorHandle(nullptr);
}
return TensorHandle(tensor_handle);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_

View File

@ -5,12 +5,22 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
cc_library(
name = "tensor_types_test_util",
testonly = True,
hdrs = ["tensor_types_test_util.h"],
deps = [
"//tensorflow/c:tf_datatype",
],
)
tf_cc_test( tf_cc_test(
name = "tensor_test", name = "tensor_test",
srcs = [ srcs = [
"tensor_test.cc", "tensor_test.cc",
], ],
deps = [ deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype", "//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:status", "//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor", "//tensorflow/cc/experimental/base/public:tensor",
@ -19,3 +29,22 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
], ],
) )
tf_cc_test(
name = "tensorhandle_test",
srcs = [
"tensorhandle_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
"//tensorflow/cc/experimental/base/public:tensorhandle",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -16,69 +16,22 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/tensor.h" #include "tensorflow/cc/experimental/base/public/tensor.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
#include <cstdint>
#include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace { namespace {
// Each of the following struct types have two members: a kDType that using tensorflow::experimental::cc::Status;
// corresponds to a TF_Datatype enum value, and a typedef "type" using tensorflow::experimental::cc::Tensor;
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
struct DoubleType { using SimpleTypes = ::testing::Types<
using type = double; tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
static constexpr TF_DataType kDType = TF_DOUBLE; tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
}; tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
using SimpleTypes =
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
template <typename T> template <typename T>
class ConstructScalarTensorTest : public ::testing::Test {}; class ConstructScalarTensorTest : public ::testing::Test {};
@ -88,14 +41,13 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and // and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements. // number of elements.
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) { TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status; Status status;
TF_DataType dtype = TypeParam::kDType; TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42; typename TypeParam::type value = 42;
cc::Tensor tensor = Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, /*data=*/&value,
/*data=*/&value, /*len=*/sizeof(value),
/*len=*/sizeof(value), /*deleter=*/[](void*, size_t) {}, &status);
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message(); ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0); EXPECT_EQ(tensor.dims(), 0);
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and // and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements. // number of elements.
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status; Status status;
TF_DataType dtype = TypeParam::kDType; TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype. // This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29}; std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
std::vector<int64_t> shape; std::vector<int64_t> shape;
shape.push_back(value.size()); shape.push_back(value.size());
cc::Tensor tensor = cc::Tensor::FromBuffer( Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape, /*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(), /*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type), /*len=*/value.size() * sizeof(typename TypeParam::type),
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 1); EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype); EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view( tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size()); reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100); EXPECT_EQ(tensor_view[1], 100);
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and // and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements. // number of elements.
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status; Status status;
TF_DataType dtype = TypeParam::kDType; TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype. // This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29}; std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3. // Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3}); std::vector<int64_t> shape({2, 3});
cc::Tensor tensor = cc::Tensor::FromBuffer( Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape, /*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(), /*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type), /*len=*/value.size() * sizeof(typename TypeParam::type),
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 2); EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype); EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view( tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size()); reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42); EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100); EXPECT_EQ(tensor_view[1], 100);
@ -185,22 +137,22 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
TEST(CPPTensorAPI, ConstructTensorFromBuffer) { TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
bool done = false; bool done = false;
cc::Status status; Status status;
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100}); std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
{ {
// data_vector is a rank 1 tensor. // data_vector is a rank 1 tensor.
std::vector<int64_t> shape; std::vector<int64_t> shape;
shape.push_back(data_vector.size()); shape.push_back(data_vector.size());
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) { Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
done = true; done = true;
}; };
cc::Tensor tensor = Tensor tensor =
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(), /*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t), /*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status); /*deleter=*/callback, &status);
ASSERT_TRUE(status.ok()) << status.message(); ASSERT_TRUE(status.ok()) << status.message();
} }
// At this point, tensor has been destroyed, and the deleter callback should // At this point, tensor has been destroyed, and the deleter callback should
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
} }
} // namespace } // namespace
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
namespace tensorflow {
// Each of the following struct types have two members: a kDType that
// corresponds to a TF_Datatype enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
// Tags the C++ `float` type with its TF_DataType enum value (TF_FLOAT).
struct FloatType {
  static constexpr TF_DataType kDType = TF_FLOAT;
  using type = float;
};
// Tags the C++ `double` type with its TF_DataType enum value (TF_DOUBLE).
struct DoubleType {
  static constexpr TF_DataType kDType = TF_DOUBLE;
  using type = double;
};
// Tags the C++ `int32_t` type with its TF_DataType enum value (TF_INT32).
struct Int32Type {
  static constexpr TF_DataType kDType = TF_INT32;
  using type = int32_t;
};
// Tags the C++ `uint8_t` type with its TF_DataType enum value (TF_UINT8).
struct UINT8Type {
  static constexpr TF_DataType kDType = TF_UINT8;
  using type = uint8_t;
};
// Tags the C++ `int8_t` type with its TF_DataType enum value (TF_INT8).
struct INT8Type {
  static constexpr TF_DataType kDType = TF_INT8;
  using type = int8_t;
};
// Tags the C++ `int64_t` type with its TF_DataType enum value (TF_INT64).
struct INT64Type {
  static constexpr TF_DataType kDType = TF_INT64;
  using type = int64_t;
};
// Tags the C++ `uint16_t` type with its TF_DataType enum value (TF_UINT16).
struct UINT16Type {
  static constexpr TF_DataType kDType = TF_UINT16;
  using type = uint16_t;
};
// Tags the C++ `uint32_t` type with its TF_DataType enum value (TF_UINT32).
struct UINT32Type {
  static constexpr TF_DataType kDType = TF_UINT32;
  using type = uint32_t;
};
// Tags the C++ `uint64_t` type with its TF_DataType enum value (TF_UINT64).
struct UINT64Type {
  static constexpr TF_DataType kDType = TF_UINT64;
  using type = uint64_t;
};
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;
// The list of dtype-tagging structs (declared in tensor_types_test_util.h)
// that every typed test below is instantiated over, so each test body runs
// once per supported element type.
using SimpleTypes = ::testing::Types<
    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
// Empty typed-test fixture for scalar TensorHandle construction; the type
// parameter T carries the dtype under test (one of SimpleTypes).
template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
// Builds a scalar tensor for every dtype in "SimpleTypes", wraps it in a
// TensorHandle, resolves the handle back to a Tensor, and verifies the
// round-tripped dims, dtype, value, byte count, and element count.
TYPED_TEST(ConstructScalarTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  Status s;
  RuntimeBuilder builder;
  std::unique_ptr<Runtime> rt = builder.Build(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  const TF_DataType dtype = TypeParam::kDType;
  typename TypeParam::type scalar = 42;
  // No-op deleter: `scalar` is owned by this stack frame.
  Tensor src =
      Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
                         /*data=*/&scalar,
                         /*len=*/sizeof(scalar),
                         /*deleter=*/[](void*, size_t) {}, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  TensorHandle handle = TensorHandle::FromTensor(src, *rt, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  Tensor resolved = handle.Resolve(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  EXPECT_EQ(resolved.dims(), 0);
  EXPECT_EQ(resolved.dtype(), dtype);
  EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(resolved.data()), 42);
  EXPECT_EQ(resolved.num_bytes(), sizeof(typename TypeParam::type));
  EXPECT_EQ(resolved.num_elements(), 1);
}
// Empty typed-test fixture for rank-1 TensorHandle construction; the type
// parameter T carries the dtype under test (one of SimpleTypes).
template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
// Builds a rank-1 tensor for every dtype in "SimpleTypes", round-trips it
// through a TensorHandle, and verifies dims, dtype, element values, byte
// count, and element count.
TYPED_TEST(Construct1DTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  using ValueT = typename TypeParam::type;
  Status s;
  RuntimeBuilder builder;
  std::unique_ptr<Runtime> rt = builder.Build(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  const TF_DataType dtype = TypeParam::kDType;
  // Backing data for the rank-1 tensor.
  std::vector<ValueT> data = {42, 100, 0, 1, 4, 29};
  std::vector<int64_t> shape = {static_cast<int64_t>(data.size())};
  // No-op deleter: `data` is owned by this stack frame.
  Tensor src = Tensor::FromBuffer(
      /*dtype=*/dtype, /*shape=*/shape,
      /*data=*/data.data(),
      /*len=*/data.size() * sizeof(ValueT),
      /*deleter=*/[](void*, size_t) {}, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  TensorHandle handle = TensorHandle::FromTensor(src, *rt, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  Tensor resolved = handle.Resolve(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  EXPECT_EQ(resolved.dims(), 1);
  EXPECT_EQ(resolved.dtype(), dtype);
  tensorflow::gtl::ArraySlice<ValueT> view(
      reinterpret_cast<ValueT*>(resolved.data()), data.size());
  // Resolved contents must match the buffer the tensor was built from.
  for (size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(view[i], data[i]);
  }
  EXPECT_EQ(resolved.num_bytes(), data.size() * sizeof(ValueT));
  EXPECT_EQ(resolved.num_elements(), data.size());
}
// Empty typed-test fixture for rank-2 TensorHandle construction; the type
// parameter T carries the dtype under test (one of SimpleTypes).
template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
// Builds a rank-2 (2 x 3) tensor for every dtype in "SimpleTypes",
// round-trips it through a TensorHandle, and verifies dims, dtype, element
// values, byte count, and element count.
TYPED_TEST(Construct2DTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  using ValueT = typename TypeParam::type;
  Status s;
  RuntimeBuilder builder;
  std::unique_ptr<Runtime> rt = builder.Build(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  const TF_DataType dtype = TypeParam::kDType;
  // Flat backing data for the 2 x 3 tensor (row-major).
  std::vector<ValueT> data = {42, 100, 0, 1, 4, 29};
  std::vector<int64_t> shape = {2, 3};
  // No-op deleter: `data` is owned by this stack frame.
  Tensor src = Tensor::FromBuffer(
      /*dtype=*/dtype, /*shape=*/shape,
      /*data=*/data.data(),
      /*len=*/data.size() * sizeof(ValueT),
      /*deleter=*/[](void*, size_t) {}, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  TensorHandle handle = TensorHandle::FromTensor(src, *rt, &s);
  ASSERT_TRUE(s.ok()) << s.message();
  Tensor resolved = handle.Resolve(&s);
  ASSERT_TRUE(s.ok()) << s.message();
  EXPECT_EQ(resolved.dims(), 2);
  EXPECT_EQ(resolved.dtype(), dtype);
  tensorflow::gtl::ArraySlice<ValueT> view(
      reinterpret_cast<ValueT*>(resolved.data()), data.size());
  // Resolved contents must match the buffer the tensor was built from.
  for (size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(view[i], data[i]);
  }
  EXPECT_EQ(resolved.num_bytes(), data.size() * sizeof(ValueT));
  EXPECT_EQ(resolved.num_elements(), data.size());
}
} // namespace
} // namespace tensorflow

View File

@ -84,7 +84,7 @@ cc_library(
"//tensorflow/core:ops", "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
]) + if_android([ ]) + if_android([
"//tensorflow/core:android_tensorflow_lib", "//tensorflow/core:portable_tensorflow_lib",
]), ]),
) )

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI. // ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to an array of // ConcreteFunctionList helps convert an opaque pointer to an array of
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// FunctionMetadata stores additional function information, including // FunctionMetadata stores additional function information, including
@ -40,6 +41,7 @@ class FunctionMetadata final {
}; };
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow { namespace tensorflow {
namespace experimental {
namespace cc { namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models // SavedModelAPI offers a way to load Tensorflow Saved Models
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
} }
} // namespace cc } // namespace cc
} // namespace experimental
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_

View File

@ -26,10 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace { namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::SavedModelAPI;
using tensorflow::experimental::cc::Status;
constexpr char kTestData[] = "cc/saved_model/testdata"; constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {}; class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
cc::Status status; Status status;
cc::RuntimeBuilder builder; RuntimeBuilder builder;
bool use_tfrt = GetParam(); bool use_tfrt = GetParam();
if (use_tfrt) { if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
} }
builder.SetUseTFRT(use_tfrt); builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status); std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message(); ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"}; std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<cc::SavedModelAPI> model = std::unique_ptr<SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags); SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*, // That unblocks writing other tests that require a TF_SavedModel*,
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
} }
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
cc::Status status; Status status;
cc::RuntimeBuilder builder; RuntimeBuilder builder;
bool use_tfrt = GetParam(); bool use_tfrt = GetParam();
if (use_tfrt) { if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
} }
builder.SetUseTFRT(use_tfrt); builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status); std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message(); ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<cc::SavedModelAPI> model = std::unique_ptr<SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status); SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*, // That unblocks writing other tests that require a TF_SavedModel*,
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
} // namespace } // namespace
} // namespace tensorflow

View File

@ -42,7 +42,8 @@ def tf_library(
mlir_components = "None", mlir_components = "None",
deps = None, deps = None,
tags = []): tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code. """Runs tfcompile to compile a TensorFlow graph into executable code with fast
math enabled on cpu.
Given an invocation of tf_library(name="foo", ...), generates the following Given an invocation of tf_library(name="foo", ...), generates the following
build targets: build targets:
@ -207,6 +208,15 @@ def tf_library(
srcs.append(debug_info) srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")" debug_info_flag = " --debug_info=$(location " + debug_info + ")"
default_fast_math_xla_flags = ("XLA_FLAGS='" +
"--xla_cpu_enable_fast_math=true " +
"--xla_cpu_fast_math_honor_nans=false " +
"--xla_cpu_fast_math_honor_infs=false " +
"--xla_cpu_fast_math_honor_functions=false " +
"--xla_cpu_fast_math_honor_division=false " +
"--xla_cpu_enable_fast_min_max=true " +
"$${XLA_FLAGS:-}' ")
native.genrule( native.genrule(
name = ("gen_" + name), name = ("gen_" + name),
srcs = srcs, srcs = srcs,
@ -216,6 +226,7 @@ def tf_library(
function_object_file, function_object_file,
], ],
cmd = ( cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " + "CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" + "$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" + " --graph=$(location " + tfcompile_graph + ")" +
@ -256,6 +267,7 @@ def tf_library(
session_module_pb, session_module_pb,
], ],
cmd = ( cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " + "CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" + "$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" + " --graph=$(location " + tfcompile_graph + ")" +

View File

@ -67,6 +67,8 @@ int main(int argc, char** argv) {
flags.entry_point = "entry"; flags.entry_point = "entry";
flags.debug_info_path_begin_marker = ""; flags.debug_info_path_begin_marker = "";
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
// generally the preferred case.
std::vector<tensorflow::Flag> flag_list; std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags); AppendMainFlags(&flag_list, &flags);
xla::AppendDebugOptionsFlags(&flag_list); xla::AppendDebugOptionsFlags(&flag_list);

View File

@ -251,7 +251,7 @@ cc_library(
visibility = [":friends"], visibility = [":friends"],
deps = select({ deps = select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib", "//tensorflow/core:portable_tensorflow_lib",
], ],
"//conditions:default": [ "//conditions:default": [
"//tensorflow/core:graph", "//tensorflow/core:graph",

View File

@ -77,10 +77,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink",
"//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink",
"//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt",
"//tensorflow/compiler/mlir/tfrt:tf_to_corert",
], ],
) )
@ -152,7 +148,6 @@ tf_cc_binary(
"//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
"//tensorflow/compiler/mlir/tfrt:compatibility_analysis",
"//tensorflow/compiler/mlir/xla:xla_mlir_translate", "//tensorflow/compiler/mlir/xla:xla_mlir_translate",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",

View File

@ -31,7 +31,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
], ],
) )

View File

@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions( Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
std::string node_def_str;
if (!node_def.SerializeToString(&node_def_str)) {
return emitError(loc, "failed to serialize tensorflow node_def"),
llvm::None;
}
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
return builder_.CreateVector(flex_builder->GetBuffer()); return builder_.CreateVector(flex_builder->GetBuffer());
} }
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>(); auto flex_builder = absl::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap(); size_t map_start = flex_builder->StartMap();
for (const auto& pair : node_def.attr()) { using Item = std::pair<std::string, ::tensorflow::AttrValue>;
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
std::sort(attrs.begin(), attrs.end(),
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
for (const Item& pair : attrs) {
const char* key = pair.first.c_str(); const char* key = pair.first.c_str();
const auto& attr = pair.second; const ::tensorflow::AttrValue& attr = pair.second;
switch (attr.value_case()) { switch (attr.value_case()) {
case ::tensorflow::AttrValue::kS: case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s()); flex_builder->String(key, attr.s());

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td" include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
}]; }];
let arguments = ( let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
TFL_TensorOf<[F32, QI8, QUI8]>:$filter, TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
TFL_TensorOfOrNone<[F32, I32]>:$bias, TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
I32Attr:$dilation_h_factor, I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor, I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function, TFL_AFAttr:$fused_activation_function,
@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
I32Attr:$stride_w I32Attr:$stride_w
); );
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
let hasOptions = 0b1; let hasOptions = 0b1;
} }
@ -1561,10 +1561,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
} }
def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, def TFL_HardSwishOp: TFL_Op<"hard_swish", [
SameOperandsAndResultShape, NoSideEffect,
SameOperandsAndResultType, SameOperandsAndResultShape,
TFL_GpuTargetOp]> { PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_GpuTargetOp]> {
let summary = "Hardswish activation function."; let summary = "Hardswish activation function.";
let description = [{ let description = [{
Computes hard-swish activation function Computes hard-swish activation function
@ -1574,7 +1576,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input); let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out); let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output);
let hasOptions = 0; let hasOptions = 0;
} }
@ -1606,7 +1608,8 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
SameOperandsAndResultShape, SameOperandsAndResultShape,
NoSideEffect, NoSideEffect,
SameOperandsAndResultType]> { PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Leaky Relu operator"; let summary = "Leaky Relu operator";
let description = [{ let description = [{
@ -1740,7 +1743,8 @@ def TFL_LogOp: TFL_Op<"log", [
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
NoSideEffect, NoSideEffect,
SameOperandsAndResultShape, SameOperandsAndResultShape,
SameOperandsAndResultType, PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
// zero_point = max_value // zero_point = max_value
// scale = -log_softmax_output_min / (max_value + 1) // scale = -log_softmax_output_min / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>, FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
@ -1896,11 +1900,11 @@ Rounds the values of a tensor to the nearest integer, element-wise.
}]; }];
let arguments = (ins let arguments = (ins
TFL_TensorOf<[F32]>:$x TFL_FpTensor:$x
); );
let results = (outs let results = (outs
TFL_TensorOf<[F32]>:$y TFL_FpTensor:$y
); );
} }
@ -2443,9 +2447,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
Computes element-wise reverse square root of input Computes element-wise reverse square root of input
}]; }];
let arguments = (ins AnyTensor:$x); let arguments = (ins TFL_FpTensor:$x);
let results = (outs AnyTensor:$y); let results = (outs TFL_FpTensor:$y);
let hasFolder = 1; let hasFolder = 1;
} }
@ -3361,9 +3365,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
let results = (outs AnyTensor:$output); let results = (outs AnyTensor:$output);
} }
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect, def TFL_DensifyOp: TFL_Op<"densify", [
SameOperandsAndResultType, NoSideEffect,
NoQuantizableResult]> { PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoQuantizableResult]> {
let summary = "Densify operator"; let summary = "Densify operator";
let description = [{ let description = [{

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir { namespace mlir {
namespace lite { namespace lite {
@ -38,6 +39,7 @@ namespace lite {
TfLiteStatus QuantizeModel( TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type, const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names, const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize, bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder, flatbuffers::FlatBufferBuilder* builder,
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes // Apply quantization passes
PassManager pm(module->getContext()); PassManager pm(module->getContext());
TFL::QuantizationSpecs quant_specs; TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true; quant_specs.post_training_quantization = true;
quant_specs.disable_per_channel = disable_per_channel; quant_specs.disable_per_channel = disable_per_channel;
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
auto input_tf_type = tflite::TflTypeToTfType(input_type); auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) { if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true; emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) { } else if (input_tf_type == tensorflow::DT_UINT8 ||
quant_specs.inference_type = tensorflow::DT_QUINT8; input_tf_type == tensorflow::DT_INT8 ||
input_tf_type == tensorflow::DT_INT16) {
quant_specs.inference_type = input_tf_type;
} }
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));

View File

@ -26,11 +26,13 @@ namespace mlir {
namespace lite { namespace lite {
// Quantize the `input_model` and write the result to a flatbuffer `builder`. // Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8. // The `input_type`, `output_type` and `inference_type` can be
// float32/qint8/int8/int16.
// Return partially quantized model if `fully_quantize` is false. // Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel( TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type, const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names, const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize, bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder, flatbuffers::FlatBufferBuilder* builder,

View File

@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
tflite::StderrReporter error_reporter; tflite::StderrReporter error_reporter;
return mlir::lite::QuantizeModel( return mlir::lite::QuantizeModel(
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, *model, tflite::TensorType_INT8, tflite::TensorType_INT8,
tflite::TensorType_INT8, {},
/*disable_per_channel=*/false, /*disable_per_channel=*/false,
/*fully_quantize=*/true, builder, &error_reporter); /*fully_quantize=*/true, builder, &error_reporter);
} }

View File

@ -90,7 +90,7 @@ struct QuantizationSpecs {
bool RunWeightQuantization() const { return weight_quantization; } bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type. // Whether this inference type represents a signed storage type.
bool IsSignedInferenceType() { bool IsSignedInferenceType() const {
switch (inference_type) { switch (inference_type) {
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16: case tensorflow::DT_QUINT16:
@ -102,7 +102,7 @@ struct QuantizationSpecs {
// Gets the width of this quantization type. Returns 0 if it isn't a // Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type. // quantization type.
int64_t GetQuantizationTypeWidth() { int64_t GetQuantizationTypeWidth() const {
switch (inference_type) { switch (inference_type) {
case tensorflow::DT_QINT8: case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir { namespace mlir {
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
} }
}; };
// Fold Extra Requantize ops if the preceding ops has free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
explicit FoldTrivalRequantizeOp(MLIRContext* context)
: OpRewritePattern<RQ>(context, 1) {}
LogicalResult matchAndRewrite(RQ op,
PatternRewriter& rewriter) const override {
Value pre_quantized = op.input();
auto pre_quantized_type =
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
if (!pre_quantized_type) return failure();
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
llvm::SmallVector<Type, 4> new_output_types;
for (auto result : def->getResults()) {
result.getUsers().begin()->dump();
op.dump();
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
new_output_types.push_back(result.getType());
}
}
// Remove this rescale op.
rewriter.replaceOp(op, {pre_quantized});
// Replace the output scale of the preceding op.
rewriter.setInsertionPointAfter(def);
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
def->getOperands(), new_output_types,
def->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
rewriter.replaceOp(def, new_op->getResults());
return success();
}
};
// Given a quantized type `input`, magnifying its scales by the factor stored in // Given a quantized type `input`, magnifying its scales by the factor stored in
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned. // dimension size of `input` or isn't floating-point, nullptr will be returned.

View File

@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
return %1 : tensor<64xf32> return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent // CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %cst = constant dense<64> : tensor<1xi32> // CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> // CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return // CHECK: return %[[RESHAPE]]
} }
// Checks that tfl.reshape should be removed if its output has more than one // Checks that tfl.reshape should be removed if its output has more than one
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
return %3 : tensor<64xf32> return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse // CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %cst = constant dense<64> : tensor<1xi32> // CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> // CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> // CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1 // CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
// CHECK: return %2 // CHECK: return %[[RESULT]]
} }
// Checks that tfl.reshape should be kept if its output has more than one // Checks that tfl.reshape should be kept if its output has more than one
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
return %0, %1 : tensor<16x4xf32>, tensor<64xf32> return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse // CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32> // CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32> // CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> // CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> // CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %0, %1 // CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
} }
// Checks that tfl.reshape should be removed if its output type is the same // Checks that tfl.reshape should be removed if its output type is the same

View File

@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
%2 = constant dense< 3.5> : tensor<4xf32> %2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32> // CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32> // CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32> // CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32> // CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32> // CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32> // CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32> %2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32> // CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32> // CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32> // CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32> // CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32> %2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32> // CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32> // CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32> // CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32> %2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32> // CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32> // CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32> // CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32> // CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32> %2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32> // CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32> // CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32> // CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32> // CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> %6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
return %2 : tensor<4xi32> return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> // CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_splat_dense_int // CHECK-LABEL: @add_splat_dense_int
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
return %2 : tensor<4xi32> return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> // CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_dense_int_same_shape // CHECK-LABEL: @add_dense_dense_int_same_shape
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
return %2 : tensor<4xi32> return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32> // CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_dense_int_trailing_dim // CHECK-LABEL: @add_dense_dense_int_trailing_dim
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32> return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> // CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> // CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1 // CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
} }
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n // CHECK-LABEL: @add_dense_dense_int_mixing_1_n
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> %0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_splat_float // CHECK-LABEL: @add_dense_splat_float
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
return %2 : tensor<4xf32> return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> // CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_splat_dense_float // CHECK-LABEL: @add_splat_dense_float
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
return %2 : tensor<4xf32> return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> // CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_dense_float_same_shape // CHECK-LABEL: @add_dense_dense_float_same_shape
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
return %2 : tensor<4xf32> return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> // CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_dense_float_trailing_dim // CHECK-LABEL: @add_dense_dense_float_trailing_dim
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32> return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> // CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> // CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1 // CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
} }
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n // CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @rank // CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> { func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32> %cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32> %0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32> return %0 : tensor<1xi32>
} }
// CHECK-LABEL: @rank_input_known_rank // CHECK-LABEL: @rank_input_known_rank
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> { func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32> %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32> return %0 : tensor<1xi32>
} }
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32> %shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> // CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32> %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32> return %0 : tensor<4xi32>
} }
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32> %shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32> %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
} }
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
// CHECK-LABEL: @pseudo_const // CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> { func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32> // CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32> return %0 : tensor<i32>
} }
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
%cst_1 = constant dense<4> : tensor<i32> %cst_1 = constant dense<4> : tensor<i32>
%cst_2 = constant dense<1> : tensor<i32> %cst_2 = constant dense<1> : tensor<i32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32> %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
} }
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
%cst_1 = constant dense<4.0> : tensor<f32> %cst_1 = constant dense<4.0> : tensor<f32>
%cst_2 = constant dense<1.0> : tensor<f32> %cst_2 = constant dense<1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
%cst_1 = constant dense<-4.0> : tensor<f32> %cst_1 = constant dense<-4.0> : tensor<f32>
%cst_2 = constant dense<-1.0> : tensor<f32> %cst_2 = constant dense<-1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
%cst_1 = constant dense<7.0> : tensor<f32> %cst_1 = constant dense<7.0> : tensor<f32>
%cst_2 = constant dense<1.5> : tensor<f32> %cst_2 = constant dense<1.5> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32> %cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32> return %0 : tensor<3xi32>
} }
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32> %cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32> %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
} }
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[1, 0]> : tensor<2xi32> %cst_perm = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[0, 1]> : tensor<2xi32> %cst_perm = constant dense<[0, 1]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32> %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32> %cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32> %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
return %0 : tensor<4x2x3xi32> return %0 : tensor<4x2x3xi32>
} }
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %87 : tensor<?xi32> return %87 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic // CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
return %2 : tensor<?xi32> return %2 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32> // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]] // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @concat_2_tensors_1_empty // CHECK-LABEL: @concat_2_tensors_1_empty
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32> %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32> return %3 : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32> // CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return [[cst]] : tensor<2xi32> // CHECK: return %[[CST]] : tensor<2xi32>
} }
// CHECK-LABEL: @concat_3_tensors_1_empty // CHECK-LABEL: @concat_3_tensors_1_empty
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32> %3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32> return %3 : tensor<?xi32>
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %0 : tensor<?xi32> // CHECK: return %0 : tensor<?xi32>
} }
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32> %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
return %0 : tensor<2x2x3xi32> return %0 : tensor<2x2x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK-NOT: constant-dense // CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation" // CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]] // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @concatConstantTensorsMiddleDim // CHECK-LABEL: @concatConstantTensorsMiddleDim
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32> %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
return %0 : tensor<1x4x3xi32> return %0 : tensor<1x4x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK-NOT: constant-dense // CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation" // CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]] // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @concatConstantTensorsLastDim // CHECK-LABEL: @concatConstantTensorsLastDim
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32> %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
return %0 : tensor<1x2x6xi32> return %0 : tensor<1x2x6xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK-NOT: constant-dense // CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation" // CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]] // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n // CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }
// CHECK-LABEL: @div_dense_different_rank // CHECK-LABEL: @div_dense_different_rank
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
return %0 : tensor<1x2x2xf32> return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst // CHECK: return %[[CST]]
} }

View File

@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
} }
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2I64Axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> { func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32> %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>

View File

@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ], // CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ], // CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ] // CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: }, { // CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2, // CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ], // CHECK-NEXT: inputs: [ 3 ],

View File

@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1 // CHECK-NEXT: return %[[split]]#0, %[[split]]#1
} }
// CHECK-LABEL: RemoveTrival
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
%1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
%2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
// CHECK-NEXT: return %[[fc]]
}
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
%cst = constant dense<[1, 1001]> : tensor<2xi32> %cst = constant dense<[1, 1001]> : tensor<2xi32>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>> %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>

View File

@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.default_ranges.second.hasValue()) { quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0), quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0))); quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass()); pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass( pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));

View File

@ -46,8 +46,11 @@ namespace {
class DefaultQuantParamsPass class DefaultQuantParamsPass
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> { : public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
public: public:
explicit DefaultQuantParamsPass(double default_min, double default_max) explicit DefaultQuantParamsPass(double default_min, double default_max,
: default_min_(default_min), default_max_(default_max) {} bool is_signed)
: default_min_(default_min),
default_max_(default_max),
is_signed_(is_signed) {}
void runOnFunction() override; void runOnFunction() override;
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
double default_min_; double default_min_;
double default_max_; double default_max_;
bool is_signed_;
quant::QuantParams default_quant_params_; quant::QuantParams default_quant_params_;
}; };
} // namespace } // namespace
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
default_quant_params_ = quant::fakeQuantAttrsToType( default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(), builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false, /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
builder.getF32Type()); builder.getF32Type(), is_signed_);
} }
return default_quant_params_; return default_quant_params_;
} }
// Creates an instance of the default quant parameters pass. // Creates an instance of the default quant parameters pass.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass( std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) { double default_min, double default_max, bool is_signed) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max); return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
is_signed);
} }
// Registers this pass with default values, only for test // Registers this pass with default values, only for test
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant", "tfl-default-quant",
"Apply quantization with default quantization parameter", [] { "Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0, return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
/*default_max=*/1.0); /*default_max=*/1.0,
/*is_signed=*/false);
}); });
} // namespace TFL } // namespace TFL

View File

@ -321,7 +321,8 @@ void DenseToSparse::runOnFunction() {
if (result.needs_densify) { if (result.needs_densify) {
const auto value = op->getOperand(operand); const auto value = op->getOperand(operand);
auto densify = builder.create<DensifyOp>(op->getLoc(), value); auto densify =
builder.create<DensifyOp>(op->getLoc(), value.getType(), value);
value.replaceAllUsesWith(densify); value.replaceAllUsesWith(densify);
densify.setOperand(value); densify.setOperand(value);
} }

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
return success(); return success();
} }
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  // Already i32-typed: hand the attribute back unchanged.
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  // Wider integer attribute (e.g. i64): only convertible when the value
  // actually fits in 32 bits.
  const int64_t value = attr.getInt();
  const bool fits_in_i32 = value <= std::numeric_limits<int>::max() &&
                           value >= std::numeric_limits<int>::min();
  if (!fits_in_i32) return failure();

  // Rebuild the attribute with the same value but an i32 type.
  auto i32_type = IntegerType::get(/*width=*/32, attr.getContext());
  *attr_i32 = IntegerAttr::get(i32_type, value);
  return success();
}
}
LogicalResult ConvertTFConcatV2Op::matchAndRewrite( LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const { Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op); auto tf_concat_op = cast<TF::ConcatV2Op>(op);
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
// Extract axis attribute from constant axis tensor // Extract axis attribute from constant axis tensor
ElementsAttr axis; ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
// "axis" operand could be a i64 tensor. Resolve it here.
IntegerAttr axis_i32;
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
StringAttr fused_activation_function = StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext()); StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>( rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis), op, output_type, values, axis_i32, fused_activation_function);
fused_activation_function);
return success(); return success();
} }

View File

@ -859,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
target.addLegalOp<ConstantOp>(); target.addLegalOp<ConstantOp>();
target.addLegalOp<FuncOp>(); target.addLegalOp<FuncOp>();
target.addLegalOp<ReturnOp>(); target.addLegalOp<ReturnOp>();
target.addLegalOp<TFL::CustomOp>();
// Register fused LSTM/RNN ops as legal. // Register fused LSTM/RNN ops as legal.
target.addLegalOp<TFL::LSTMOp>(); target.addLegalOp<TFL::LSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>(); target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();

View File

@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance of the TensorFlow Lite dialect pass to add default // Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters. // quantization parameters.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass( std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max); double default_min, double default_max, bool is_signed);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense // Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format. // tensor to sparse format.

View File

@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
auto func = getFunction(); auto func = getFunction();
auto* ctx = func.getContext(); auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns); TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
applyPatternsAndFoldGreedily(func, patterns); applyPatternsAndFoldGreedily(func, patterns);
if (!emit_quant_adaptor_ops_) { if (!emit_quant_adaptor_ops_) {

View File

@ -70,6 +70,7 @@ class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> { : public PassWrapper<PrepareQuantizePass, FunctionPass> {
public: public:
// Constructor used by the PassRegistration and enforce uint8 quantization. // Constructor used by the PassRegistration and enforce uint8 quantization.
// This is only used by test.
explicit PrepareQuantizePass() { explicit PrepareQuantizePass() {
if (quantize_signed) if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8; quant_specs_.inference_type = tensorflow::DT_QINT8;
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
// convert all of them to signed. // convert all of them to signed.
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType(); bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
if (is_signed) { if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx); patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters. // Convert quant stats to int8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false. // Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx); patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
} else { } else {
// Convert quant stats to uint8 quantization parameters. // Convert quant stats to uint8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false. // Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, false, ctx); patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
} }
applyPatternsAndFoldGreedily(func, patterns); applyPatternsAndFoldGreedily(func, patterns);

View File

@ -0,0 +1,41 @@
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")

package(licenses = ["notice"])

# Python extension exposing a minimal slice of the MLIR C++ API (contexts,
# builders, ops, types, attributes) via pybind11. The srcs split one init_*
# translation unit per API area; mlir_wrapper.cc ties them together.
tf_python_pybind_extension(
    name = "mlir_wrapper",
    srcs = [
        "attrs.cc",
        "basic_classes.cc",
        "builders.cc",
        "mlir_wrapper.cc",
        "mlir_wrapper.h",
        "ops.cc",
        "types.cc",
    ],
    module_name = "mlir_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@pybind11",
    ],
)

# Python extension wrapping LLVM's FileCheck so tests can run CHECK-style
# verification over generated MLIR without shelling out to the binary.
tf_python_pybind_extension(
    name = "filecheck_wrapper",
    srcs = ["filecheck_wrapper.cc"],
    module_name = "filecheck_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@pybind11",
    ],
)

View File

@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers the MLIR attribute classes exposed to Python.
void init_attrs(py::module& m) {
  // The Attribute base must be bound before any derived attribute class.
  py::class_<mlir::Attribute>(m, "Attribute");

  auto integer_attr =
      py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr");
  integer_attr.def(
      "get", py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
}

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers the core MLIR IR classes (context, locations, regions, blocks
// and values) with the Python module.
//
// Fix: "front" was previously bound twice on Region (identical duplicate
// registration); the redundant second binding is removed.
void init_basic_classes(py::module& m) {
  py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());

  py::class_<mlir::Location>(m, "Location");

  py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
      .def("get", &mlir::UnknownLoc::get);

  // Blocks returned by reference remain owned by the region.
  py::class_<mlir::Region>(m, "Region")
      .def("back", &mlir::Region::back, py::return_value_policy::reference)
      .def("front", &mlir::Region::front, py::return_value_policy::reference)
      // The freshly allocated block is appended to the region, which takes
      // ownership of it.
      .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
      .def("push_back", &mlir::Region::push_back)
      .def("size", [](mlir::Region& r) { return r.getBlocks().size(); });

  py::class_<mlir::Block::iterator>(m, "Block_Iterator");
  py::class_<mlir::Block>(m, "Block")
      // NOTE(review): "new" leaks unless the block is later inserted into a
      // region (which then owns it) — TODO confirm callers always insert.
      .def("new", ([]() { return new mlir::Block; }),
           py::return_value_policy::reference)
      .def("end", &mlir::Block::end)
      .def("addArgument", &mlir::Block::addArgument);

  py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
  py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
  py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
}

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
// Registers mlir::Builder / mlir::OpBuilder and their insertion-point
// helper with the Python module.
void init_builders(py::module& m) {
  auto builder = py::class_<mlir::Builder>(m, "Builder");
  builder.def(py::init<mlir::MLIRContext*>());
  // Bridges Python lists of types to the ArrayRef-based C++ signature.
  builder.def("getFunctionType",
              [](mlir::Builder& b, std::vector<mlir::Type> inputs,
                 std::vector<mlir::Type> outputs) {
                return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
                                         llvm::ArrayRef<mlir::Type>(outputs));
              });

  auto op_builder = py::class_<mlir::OpBuilder>(m, "OpBuilder");
  op_builder.def(py::init<mlir::MLIRContext*>());
  op_builder.def(py::init<mlir::Region&>());
  op_builder.def(py::init<mlir::Operation*>());
  op_builder.def(py::init<mlir::Block*, mlir::Block::iterator>());
  op_builder.def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc);
  op_builder.def("setInsertionPoint",
                 py::overload_cast<mlir::Block*, mlir::Block::iterator>(
                     &mlir::OpBuilder::setInsertionPoint));
  op_builder.def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint);
  op_builder.def("restoreInsertionPoint",
                 &mlir::OpBuilder::restoreInsertionPoint);
  // The created operation stays owned by MLIR, hence reference policy.
  op_builder.def(
      "createOperation",
      [](mlir::OpBuilder& opb, mlir::OperationState& state) {
        return opb.createOperation(state);
      },
      py::return_value_policy::reference);
  op_builder.def("getContext", &mlir::OpBuilder::getContext,
                 py::return_value_policy::reference);

  py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
      .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
}

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "llvm/Support/SourceMgr.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
// Exposes LLVM FileCheck to Python: check(input, check) -> bool, where
// `input` is the text to verify and `check` holds the CHECK directives.
//
// Fix: the error result of readCheckFile was previously ignored, so
// malformed CHECK directives fell through to checkInput; now they are
// reported as a failed check. Also drops the needless copy-initialization
// of SourceMgr.
PYBIND11_MODULE(filecheck_wrapper, m) {
  m.def("check", [](const std::string& input, const std::string& check) {
    llvm::FileCheckRequest fcr;
    llvm::FileCheck fc(fcr);
    llvm::SourceMgr SM;
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                          llvm::SMLoc());
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
                          llvm::SMLoc());
    llvm::Regex regex = fc.buildCheckPrefixRegex();
    // readCheckFile returns true on error (LLVM convention): surface that
    // as a failed match instead of checking against broken directives.
    if (fc.readCheckFile(SM, llvm::StringRef(check), regex)) return false;
    return fc.checkInput(SM, llvm::StringRef(input));
  });
}

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
// Entry point of the mlir_wrapper Python extension: exposes dialect
// registration and binds the wrapper classes defined in the init_*
// translation units.
PYBIND11_MODULE(mlir_wrapper, m) {
  // Callers invoke this before building ops so the TF, tf_executor and
  // standard dialects are known to the MLIR context.
  m.def("registerDialects", []() {
    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
  });
  // NOTE(review): the init_* calls appear order-sensitive (pybind requires
  // base classes such as Value to be bound before their subclasses) —
  // confirm before reordering.
  init_basic_classes(m);
  init_types(m);
  init_builders(m);
  init_ops(m);
  init_attrs(m);
}

View File

@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
//===- dialect_static_registration.cc -------------------------------------===// #ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
// #define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
// This file registers the RuntimeFallbackDialect.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace mlir { namespace py = pybind11;
namespace tfd {
// Static initialization for dialect registration. void init_basic_classes(py::module& m);
static DialectRegistration<RuntimeFallbackDialect> tfd_registration; void init_types(py::module& m);
void init_builders(py::module& m);
void init_ops(py::module& m);
void init_attrs(py::module& m);
} // namespace tfd #endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
} // namespace mlir

View File

@ -0,0 +1,194 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
// Python bindings for MLIR operations: the generic Operation /
// OperationState / ModuleOp / FuncOp / ReturnOp wrappers plus a hand-picked
// subset of TensorFlow dialect ops. Every "create" binding wraps
// OpBuilder::create<OpT>() and returns the untyped mlir::Operation* so the
// Python side can treat all ops uniformly.
void init_ops(py::module& m) {
  // Operations are owned by their enclosing block, so pybind must never
  // delete them: hence the py::nodelete holder.
  py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
      m, "Operation")
      .def("getRegion", &mlir::Operation::getRegion,
           py::return_value_policy::reference)
      .def("getResult", &mlir::Operation::getResult)
      .def("dump", &mlir::Operation::dump)
      .def("getNumResults", &mlir::Operation::getNumResults);

  py::class_<mlir::OperationState>(m, "OperationState")
      .def(py::init([](mlir::Location location, std::string name) {
        return mlir::OperationState(location, llvm::StringRef(name));
      }))
      .def("addTypes",
           [](mlir::OperationState& state, std::vector<mlir::Type> types) {
             state.addTypes(mlir::ArrayRef<mlir::Type>(types));
           })
      .def("addOperands",
           [](mlir::OperationState& state,
              std::vector<mlir::Value> operands) {
             state.addOperands(mlir::ArrayRef<mlir::Value>(operands));
           })
      .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
           py::return_value_policy::reference);

  py::class_<mlir::ModuleOp>(m, "ModuleOp")
      .def("create",
           [](mlir::Location location) {
             return mlir::ModuleOp::create(location);
           })
      .def("push_back",
           [](mlir::ModuleOp& module, mlir::FuncOp fn) {
             module.push_back(fn);
           })
      .def("dump", &mlir::ModuleOp::dump)
      .def("getAsStr", [](mlir::ModuleOp& module) {
        // Render the module to its textual IR form.
        std::string text;
        llvm::raw_string_ostream stream(text);
        module.print(stream);
        return stream.str();
      });

  py::class_<mlir::FuncOp>(m, "FuncOp")
      .def("create",
           [](mlir::Location location, std::string name,
              mlir::FunctionType type) {
             auto fn = mlir::FuncOp::create(location, name, type);
             fn.addEntryBlock();
             return fn;
           })
      .def(
          "getBody",
          [](mlir::FuncOp& fn) -> mlir::Region& { return fn.getBody(); },
          py::return_value_policy::reference)
      .def("getArguments",
           [](mlir::FuncOp& fn) { return fn.getArguments().vec(); })
      .def("getName", [](mlir::FuncOp& fn) { return fn.getName().str(); })
      .def("getType", &mlir::FuncOp::getType);

  py::class_<mlir::ReturnOp>(m, "ReturnOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location,
              std::vector<mlir::Value> operands) -> mlir::Operation* {
             return builder
                 .create<mlir::ReturnOp>(
                     location, mlir::ArrayRef<mlir::Value>(operands))
                 .getOperation();
           });

  // mlir::TF::AddV2Op
  py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::AddV2Op>(location, x, y)
                 .getOperation();
           });

  // mlir::TF::AnyOp — result element type is fixed to i1.
  py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location,
              mlir::Value input, mlir::Value reduction_indices,
              bool keep_dims = false) -> mlir::Operation* {
             return builder
                 .create<mlir::TF::AnyOp>(location, builder.getI1Type(), input,
                                          reduction_indices, keep_dims)
                 .getOperation();
           });

  // mlir::TF::ConstOp
  py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location,
              mlir::Attribute value) -> mlir::Operation* {
             return builder.create<mlir::TF::ConstOp>(location, value)
                 .getOperation();
           });

  // mlir::TF::EqualOp — incompatible_shape_error is hard-wired to true.
  py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder
                 .create<mlir::TF::EqualOp>(location, x, y,
                                            builder.getBoolAttr(true))
                 .getOperation();
           });

  // mlir::TF::GreaterEqualOp
  py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::GreaterEqualOp>(location, x, y)
                 .getOperation();
           });

  // mlir::TF::GreaterOp
  py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::GreaterOp>(location, x, y)
                 .getOperation();
           });

  // mlir::TF::LegacyCallOp — calls the function symbol named by `f`.
  py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location,
              std::vector<mlir::Type> output, std::vector<mlir::Value> args,
              std::string f) -> mlir::Operation* {
             return builder
                 .create<mlir::TF::LegacyCallOp>(
                     location, mlir::ArrayRef<mlir::Type>(output),
                     mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
                 .getOperation();
           });

  // mlir::TF::LessEqualOp
  py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::LessEqualOp>(location, x, y)
                 .getOperation();
           });

  // mlir::TF::LessOp
  py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::LessOp>(location, x, y)
                 .getOperation();
           });

  // mlir::TF::NegOp
  py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location,
              mlir::Value x) -> mlir::Operation* {
             return builder.create<mlir::TF::NegOp>(location, x)
                 .getOperation();
           });

  // mlir::TF::NotEqualOp — incompatible_shape_error is hard-wired to true.
  py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder
                 .create<mlir::TF::NotEqualOp>(
                     location, x, y,
                     mlir::BoolAttr::get(true, builder.getContext()))
                 .getOperation();
           });

  // mlir::TF::SubOp
  py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
      .def("create",
           [](mlir::OpBuilder& builder, mlir::Location location, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return builder.create<mlir::TF::SubOp>(location, x, y)
                 .getOperation();
           });
}

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
// Python bindings for MLIR type classes and the standard-type kind enum.
void init_types(py::module& m) {
  // Base class for all MLIR types.
  py::class_<mlir::Type> type_class(m, "Type");
  type_class.def("getKind", &mlir::Type::getKind);

  // Kind enum, nested under Type on the Python side. Only BF16 is exposed.
  py::enum_<mlir::StandardTypes::Kind>(type_class, "StandardTypes_Kind")
      .value("BF16", mlir::StandardTypes::BF16);

  // Concrete Type subclasses.
  py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
      .def("getResults", [](mlir::FunctionType& fn_type) {
        return fn_type.getResults().vec();
      });

  py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
      .def("get", &mlir::FloatType::get);

  // Disambiguate the (width, context) overload of IntegerType::get.
  py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
      .def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
                      &mlir::IntegerType::get));

  py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
      .def("get", &mlir::UnrankedTensorType::get);

  py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
      .def("get", [](std::vector<int64_t> dims, mlir::Type element_type) {
        return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(dims),
                                           element_type);
      });
}

View File

@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
] ]
tool_names = [ tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'xla-opt' 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
] ]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs) llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs',
'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot', 'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu', 'tensorflow/compiler/xla/service/mlir_gpu',

View File

@ -36,7 +36,7 @@ filegroup(
"@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
], ],
) )
@ -1075,7 +1075,7 @@ genrule(
srcs = [ srcs = [
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
"@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td",
"ir/tf_generated_ops.td", "ir/tf_generated_ops.td",
"ir/tf_op_base.td", "ir/tf_op_base.td",
@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -1278,6 +1279,7 @@ cc_library(
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
], ],
) )
@ -1292,6 +1294,7 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
], ],
) )

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir { namespace mlir {
namespace TFControlFlow { namespace TFControlFlow {

View File

@ -1765,7 +1765,7 @@ of corresponding 3-element vectors is cross-multiplied independently.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
let summary = "An Op to sum inputs across replicated TPU instances."; let summary = "An Op to sum inputs across replicated TPU instances.";
let description = [{ let description = [{
@ -1789,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> { def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
let description = [{ let description = [{
@ -2907,6 +2907,8 @@ fill([2, 3], 9) ==> [[9, 9, 9]
return Verify(*this); return Verify(*this);
}]; }];
let hasFolder = 1;
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value dims, Value value" "OpBuilder &builder, OperationState &result, Value dims, Value value"
>]; >];
@ -3625,31 +3627,6 @@ tf.imag(input) ==> [4.75, 5.75]
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
} }
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = [{
Create a copy of `x` with the updated specified rows 'i' with values 'v'.
}];
let description = [{
Creates a copy of tensor 'x' and updates the columns specified in tensor 'i'
with the values 'v'. Originally this function was mutative however for
compilation we make this operation create / operate on a copy.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> { def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the reciprocal of x element-wise."; let summary = "Computes the reciprocal of x element-wise.";
@ -4350,7 +4327,7 @@ cublas.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> { def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> {
let summary = [{ let summary = [{
Copy a tensor setting everything outside a central band in each innermost matrix to zero. Copy a tensor setting everything outside a central band in each innermost matrix to zero.
}]; }];
@ -6354,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let hasFolder = 1;
} }
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
@ -10352,6 +10331,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
// Stand-in for host-side computation embedded in an XLA program. The empty
// trait list (no NoSideEffect) means the op is conservatively treated as
// side-effecting.
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> {
let summary = [{
A pseudo-op to represent host-side computation in an XLA program.
}];
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrArrayAttr:$ancestors,
TF_ShapeAttrArray:$shapes,
// Symbol reference to the graph used for shape inference of the outputs.
SymbolRefAttr:$shape_inference_graph,
StrAttr:$key,
// NOTE(review): name suggests an estimate in nanoseconds used for
// scheduling — confirm against the op registration.
DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
// Operand/result element types are derived, not stored as explicit attrs.
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> { def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> {
let summary = "Wraps the XLA Sort operator, documented at"; let summary = "Wraps the XLA Sort operator, documented at";
@ -10400,6 +10406,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
// Receives a single tensor from the host. Empty trait list: treated as
// side-effecting, matching its communication role.
def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
let summary = "An op to receive a tensor from the host.";
let description = [{
}];
let arguments = (ins
// Static shape of the tensor to receive; `key` pairs this op with the
// matching send on the host side.
TF_ShapeAttr:$shape,
StrAttr:$key
);
let results = (outs
TF_Tensor:$output
);
// Result element type is derived from result 0 rather than stored.
TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>;
}
def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> { def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> {
let summary = "Wraps the XLA Reduce operator, documented at"; let summary = "Wraps the XLA Reduce operator, documented at";
@ -10464,6 +10488,23 @@ i=0...N-1.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
// Sends a single tensor to the host. Produces no results; the empty trait
// list keeps it conservatively side-effecting.
def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
let summary = "An op to send a tensor to the host.";
let description = [{
}];
let arguments = (ins
TF_Tensor:$input,
// `key` pairs this send with the matching receive on the host side.
StrAttr:$key
);
let results = (outs);
// Operand element type is derived from operand 0 rather than stored.
TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> {
let summary = [{ let summary = [{
Computes the eigen decomposition of a batch of self-adjoint matrices Computes the eigen decomposition of a batch of self-adjoint matrices
@ -10547,6 +10588,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
// Internal (underscore-prefixed) variant of host computation called from a
// TPU device. Slimmer than XlaHostCompute: only inputs, a pairing key, and
// the target TPU core.
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrAttr:$key,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
// Operand/result element types are derived, not stored as explicit attrs.
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
let summary = "An op that receives embeddng activations on the TPU."; let summary = "An op that receives embeddng activations on the TPU.";
@ -10605,3 +10667,44 @@ used to look up the program in the compilation cache.
TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>;
TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
} }
// Placeholder executed on the host to receive values from a running XLA
// computation; empty trait list keeps it side-effecting.
def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> {
let summary = [{
A placeholder op to receive values from a running XLA computation.
}];
let description = [{
}];
let arguments = (ins
// Runtime string tensor key plus static key/ordinal identifying the
// device-side peer of this transfer.
TF_StrTensor:$dynamic_key,
StrAttr:$key,
I64Attr:$device_ordinal
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
// Result element types are derived, not stored as an explicit attr.
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
// Placeholder executed on the host to send values to a running XLA
// computation. Mirror of _XlaRecvAtHost; produces no results.
def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> {
let summary = "A placeholder op to send values to a running XLA computation.";
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
// Runtime string tensor key plus static key/ordinal identifying the
// device-side peer of this transfer.
TF_StrTensor:$dynamic_key,
StrAttr:$key,
I64Attr:$device_ordinal
);
let results = (outs);
// Operand element types are derived, not stored as an explicit attr.
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
}

View File

@ -23,7 +23,7 @@ limitations under the License.
#define TF_OP_BASE #define TF_OP_BASE
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffects.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
"$_op.getOperand(" # opId # ").getType(), " "$_op.getOperand(" # opId # ").getType(), "
"$_op.getResult(" # resId # ").getType())">]>; "$_op.getResult(" # resId # ").getType())">]>;
// Predicate: all given type expressions are cast-compatible at runtime
// (delegates to TF::AreCastCompatible over the collected types).
class TF_AllTypesMatchPred<list<string> values> :
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
// Trait form of the predicate above: for each operand/result name `n`,
// substitutes `$n.getType()` and requires the set to be cast-compatible.
// Looser than the upstream AllTypesMatch, which requires exact equality.
class TF_AllTypesMatch<list<string> names> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TensorFlow op definitions // TensorFlow op definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -110,48 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
return !type || type.getRank() <= rank; return !type || type.getRank() <= rank;
} }
// Returns true if the given pair of TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
static bool AreCastCompatible(Type a, Type b) {
if (TensorCastOp::areCastCompatible(a, b)) return true;
// Resource types may optionally contain subtypes information that does not
// match. Check subtypes compatibility when possible, otherwise treat them as
// compatible.
auto a_or_element_type = getElementTypeOrSelf(a);
auto b_or_element_type = getElementTypeOrSelf(b);
auto a_kind = a_or_element_type.getKind();
auto b_kind = b_or_element_type.getKind();
if (a_kind == TensorFlowTypes::RESOURCE &&
b_kind == TensorFlowTypes::RESOURCE) {
auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
bool a_has_subtype = !a_resource_type.getSubtypes().empty();
bool b_has_subtype = !b_resource_type.getSubtypes().empty();
if (!a_has_subtype || !b_has_subtype) return true;
assert(a_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");
assert(b_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");
return TensorCastOp::areCastCompatible(
a_resource_type.getSubtypes().front(),
b_resource_type.getSubtypes().front());
}
// Variant types may optionally contain subtypes information that need not
// match. It is also not possible to compare subtypes for compatibility as
// their interpretation depends on the ops operating on them. So, accept all
// pairs of variant types.
return a_kind == TensorFlowTypes::VARIANT &&
b_kind == TensorFlowTypes::VARIANT;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) { static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1; return dim_or_rank == -1;
} }
@ -503,9 +461,10 @@ LogicalResult FoldOperandsPermutation(
namespace { namespace {
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant // Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0) // known to be Identity (e.g X+0)
template <typename OpT, template <
typename std::enable_if<llvm::is_one_of< typename OpT,
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr> typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) { ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType(); auto result_op_type = arithmetic_op.getResult().getType();
@ -520,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
// Mul and Div ops have identity value one while AddV2 and SubOp have identity // Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero. // value zero.
int identity = int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value); (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value);
Type element_ty = lhs_type.getElementType(); Type element_ty = lhs_type.getElementType();
Attribute identity_attr; Attribute identity_attr;
@ -537,6 +497,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
return arithmetic_op.x(); return arithmetic_op.x();
} }
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
// TODO(chhe): we could fold and add an identity to force the broadcast.
if (result_op_type != rhs_type) {
return {};
}
bool is_symmetric = bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value); (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) { if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
@ -1413,7 +1379,7 @@ static LogicalResult Verify(DynamicStitchOp op) {
auto expected_out_ty = auto expected_out_ty =
RankedTensorType::get(expected_shape, out_ty.getElementType()); RankedTensorType::get(expected_shape, out_ty.getElementType());
if (!AreCastCompatible(out_ty, expected_out_ty)) { if (!AreCastCompatible({out_ty, expected_out_ty})) {
return op.emitOpError() << "has invalid output type; should be " return op.emitOpError() << "has invalid output type; should be "
"compatible with inferred type " "compatible with inferred type "
<< expected_out_ty; << expected_out_ty;
@ -1647,7 +1613,7 @@ static ShapedType InferFillOpType(Value dims, Value value) {
llvm::SmallVector<int64_t, 4> shape; llvm::SmallVector<int64_t, 4> shape;
shape.reserve(dims_attr.getNumElements()); shape.reserve(dims_attr.getNumElements());
for (const APInt &dim : dims_attr.getValues<APInt>()) { for (const APInt dim : dims_attr.getValues<APInt>()) {
shape.push_back(dim.getSExtValue()); shape.push_back(dim.getSExtValue());
} }
return RankedTensorType::get(shape, etype); return RankedTensorType::get(shape, etype);
@ -1658,6 +1624,29 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
} }
// Constant-folds tf.Fill when the fill value is a known constant attribute.
// If the result type already has a static shape, a splat of that type is
// produced directly; otherwise the static shape is recovered from the
// constant `dims` operand, when available.
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "fill op has two operand");

  auto fill_value = operands[1].dyn_cast_or_null<ElementsAttr>();
  if (!fill_value) return {};

  auto result_ty = getType().cast<ShapedType>();
  if (result_ty.hasStaticShape()) {
    return DenseElementsAttr::get(result_ty, fill_value.getValue({}));
  }

  // Result type is not fully static: fall back to the dims operand, which
  // must itself be constant for folding to proceed.
  auto dims_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!dims_attr) return {};

  llvm::SmallVector<int64_t, 4> static_shape;
  static_shape.reserve(dims_attr.getNumElements());
  for (const APInt dim : dims_attr.getValues<APInt>())
    static_shape.push_back(dim.getSExtValue());

  auto folded_ty =
      RankedTensorType::get(static_shape, result_ty.getElementType());
  return DenseElementsAttr::get(folded_ty, fill_value.getValue({}));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// FusedBatchNormGradOp // FusedBatchNormGradOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1814,14 +1803,14 @@ static LogicalResult Verify(IfOp op) {
for (unsigned i = 0; i < expectedNumInputs; ++i) { for (unsigned i = 0; i < expectedNumInputs; ++i) {
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>(); auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>(); auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible(operandType, thenInputType)) if (!AreCastCompatible({operandType, thenInputType}))
return op.emitError( return op.emitError(
llvm::formatv("then branch input type {0} is incompatible with " llvm::formatv("then branch input type {0} is incompatible with "
"operand type {1} at index {2}", "operand type {1} at index {2}",
thenInputType, operandType, i)); thenInputType, operandType, i));
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>(); auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible(operandType, elseInputType)) if (!AreCastCompatible({operandType, elseInputType}))
return op.emitError( return op.emitError(
llvm::formatv("else branch input type {0} is incompatible with " llvm::formatv("else branch input type {0} is incompatible with "
"operand type {1} at index {2}", "operand type {1} at index {2}",
@ -1829,7 +1818,7 @@ static LogicalResult Verify(IfOp op) {
// If branches have incompatible input types that means that no tensor can // If branches have incompatible input types that means that no tensor can
// serve as input to both the functions. Hence, the op is invalid. // serve as input to both the functions. Hence, the op is invalid.
if (!AreCastCompatible(thenInputType, elseInputType)) if (!AreCastCompatible({thenInputType, elseInputType}))
return op.emitError(llvm::formatv( return op.emitError(llvm::formatv(
"branches inputs have incompatible types {0} and {1} at index {2}", "branches inputs have incompatible types {0} and {1} at index {2}",
thenInputType, elseInputType, i)); thenInputType, elseInputType, i));
@ -1845,14 +1834,14 @@ static LogicalResult Verify(IfOp op) {
for (unsigned i = 0; i < expectedNumResults; ++i) { for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = op.getResult(i).getType().cast<TensorType>(); auto resultType = op.getResult(i).getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>(); auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible(thenResultType, resultType)) if (!AreCastCompatible({thenResultType, resultType}))
return op.emitError( return op.emitError(
llvm::formatv("then branch result type {0} is incompatible with op " llvm::formatv("then branch result type {0} is incompatible with op "
"result type {1} at index {2}", "result type {1} at index {2}",
thenResultType, resultType, i)); thenResultType, resultType, i));
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>(); auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible(elseResultType, resultType)) if (!AreCastCompatible({elseResultType, resultType}))
return op.emitError( return op.emitError(
llvm::formatv("else branch result type {0} is incompatible with op " llvm::formatv("else branch result type {0} is incompatible with op "
"result type {1} at index {2}", "result type {1} at index {2}",
@ -2426,6 +2415,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<RealDivWithSqrtDivisor>(context); results.insert<RealDivWithSqrtDivisor>(context);
} }
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReshapeOp // ReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2616,9 +2609,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
<< variadic_idx_str << " to match rank of operand" << variadic_idx_str << " to match rank of operand"
<< variadic_idx_str; << variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) { } else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, verify that the result is dynamic. // The operand is an unranked tensor, print a warning if the result
return op->emitOpError("requires dynamic shape result") // is static.
<< variadic_idx_str << " for unranked operand" << variadic_idx_str; // Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
} }
Type element_type = result_ranked_type.getElementType(); Type element_type = result_ranked_type.getElementType();
@ -3789,7 +3785,7 @@ static LogicalResult Verify(WhileOp op) {
auto aType = a.second[idx]; auto aType = a.second[idx];
auto bType = b.second[idx]; auto bType = b.second[idx];
if (!AreCastCompatible(aType, bType)) if (!AreCastCompatible({aType, bType}))
return op.emitError(llvm::formatv( return op.emitError(llvm::formatv(
"{0} type {1} is incompatible with {2} type {3} at index {4}", "{0} type {1} is incompatible with {2} type {3} at index {4}",
a.first, aType, b.first, bType, idx)); a.first, aType, b.first, bType, idx));

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

View File

@ -905,5 +905,29 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
} }
// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once
// autogenerated op def matches.
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = "Updates specified rows 'i' with values 'v'.";
let description = [{
Computes `x[i, :] = v; return x`.
Originally this function is mutative however for compilation we make this
operation create / operate on a copy of `x`.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
#endif // TF_OPS #endif // TF_OPS

View File

@ -28,6 +28,134 @@ llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
if (shaped_type.hasRank()) return shaped_type.getShape(); if (shaped_type.hasRank()) return shaped_type.getShape();
return llvm::None; return llvm::None;
} }
// Merges cast compatible shapes and returns a more refined shape. The two
// shapes are cast compatible if they have the same rank and at each dimension,
// either both have same size or one of them is dynamic. Returns false if the
// given shapes are not cast compatible. The refined shape is same or more
// precise than the two input shapes.
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
llvm::ArrayRef<int64_t> b_shape,
llvm::SmallVectorImpl<int64_t>* refined_shape) {
if (a_shape.size() != b_shape.size()) return false;
int64_t rank = a_shape.size();
refined_shape->reserve(rank);
for (auto dims : llvm::zip(a_shape, b_shape)) {
int64_t dim1 = std::get<0>(dims);
int64_t dim2 = std::get<1>(dims);
if (mlir::ShapedType::isDynamic(dim1)) {
refined_shape->push_back(dim2);
continue;
}
if (mlir::ShapedType::isDynamic(dim2)) {
refined_shape->push_back(dim1);
continue;
}
if (dim1 == dim2) {
refined_shape->push_back(dim1);
continue;
}
return false;
}
return true;
}
// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element type. For element types that do not have subtypes, they
// must be equal. However for TensorFlow types such as Resource and Variant,
// that also have subtypes, we recursively check for subtype compatibilty for
// Resource types and assume all variant types are cast compatible. If either
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
//
// The returned type is same or more precise than the input types. For example,
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
// might allow operands to either be same as result type or be a ref type
// corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
bool may_ignore_ref_type_a) {
// Fast path if everything is equal.
if (a == b) return b;
auto a_tt = a.dyn_cast<mlir::TensorType>();
auto b_tt = b.dyn_cast<mlir::TensorType>();
// If only one of a or b is a tensor type, they are incompatible.
if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
// For non-tensor types, we do not need to worry about shape and can return
// early.
if (!a_tt && !b_tt) {
// Remove ref types.
if (may_ignore_ref_type_a) {
if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
a = ref_type.RemoveRef();
if (a == b) return a;
}
}
if (a.getKind() != b.getKind()) return nullptr;
// If either is not a type that contain subtypes then the types are not cast
// compatible.
auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
if (!a_wst || !b_wst) return nullptr;
// For Variant types we are more permissive right now and accept all pairs
// of Variant types. If we are more constrainted and check compatibility of
// subtypes, we might reject valid graphs.
// TODO(prakalps): Variant doesn't have a subtype, we assign it
// one, so we should only assign it one when we know the subtype. Then we
// can be more constrained and check subtypes for cast compatibility as
// well.
if (a.isa<mlir::TF::VariantType>()) return a;
// For Resource types, we recursively check the subtypes for cast
// compatibility, if possible. Otherwise treat them as compatible.
auto a_wst_st = a_wst.GetSubtypes();
auto b_wst_st = b_wst.GetSubtypes();
if (a_wst_st.empty() || b_wst_st.empty()) return a;
if (a_wst_st.size() != b_wst_st.size()) return nullptr;
llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
mlir::Type refined_st =
GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
/*may_ignore_ref_type_a=*/false);
if (!refined_st) return nullptr;
refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
}
return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
}
// For tensor types, check compatibility of both element type and shape.
mlir::Type refined_element_ty = GetCastCompatibleType(
a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
if (!refined_element_ty) return nullptr;
if (!a_tt.hasRank() && !b_tt.hasRank()) {
return mlir::UnrankedTensorType::get(refined_element_ty);
}
if (!a_tt.hasRank()) {
return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
}
if (!b_tt.hasRank()) {
return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
}
llvm::SmallVector<int64_t, 8> refined_shape;
if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
return nullptr;
return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}
} // namespace } // namespace
namespace mlir { namespace mlir {
@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
bool HasCompatibleElementTypes(Type lhs, Type rhs, bool HasCompatibleElementTypes(Type lhs, Type rhs,
bool may_ignore_ref_type_lhs) { bool may_ignore_ref_type_lhs) {
// Fast path if everything is equal. return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
if (lhs == rhs) return true; }
// In TF all values are tensors. bool AreCastCompatible(ArrayRef<Type> types) {
auto lhs_tt = lhs.cast<TensorType>(); Type common = types.front();
auto rhs_tt = rhs.cast<TensorType>(); for (auto type : types.drop_front()) {
Type refined_type =
// Verify matching element types. These should be identical dynamically, GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
// so this allows for types not yet fully refined. if (!refined_type) return false;
auto lhs_et = lhs_tt.getElementType(); common = refined_type;
auto rhs_et = rhs_tt.getElementType();
if (lhs_et == rhs_et) return true;
// Remove ref types.
if (may_ignore_ref_type_lhs) {
if (auto ref_type = lhs_et.dyn_cast<TF::TensorFlowRefType>()) {
lhs_et = ref_type.RemoveRef();
if (lhs_et == rhs_et) return true;
}
}
if (lhs_et.getKind() != rhs_et.getKind()) return false;
// If either is not type that contain subtypes then the element types don't
// match.
auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!lhs_wst || !rhs_wst) return false;
// Consider the subtype recursively.
auto lhs_wst_st = lhs_wst.GetSubtypes();
auto rhs_wst_st = rhs_wst.GetSubtypes();
if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true;
if (lhs_wst_st.size() != rhs_wst_st.size()) return false;
for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
if (!HasCompatibleElementTypes(std::get<0>(subtypes),
std::get<1>(subtypes)))
return false;
} }
return true; return true;
} }

View File

@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
bool HasCompatibleElementTypes(Type lhs, Type rhs, bool HasCompatibleElementTypes(Type lhs, Type rhs,
bool may_ignore_ref_type_lhs = false); bool may_ignore_ref_type_lhs = false);
// Returns true if all TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);
} // end namespace TF } // end namespace TF
} // end namespace mlir } // end namespace mlir

View File

@ -471,3 +471,14 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
// CHECK: return [[VAL0]] // CHECK: return [[VAL0]]
return %0 : tensor<i32> return %0 : tensor<i32>
} }
// CHECK-LABEL: @foldFill
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) {
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32>
}

Some files were not shown because too many files have changed in this diff Show More