Merge branch 'master' into resolve-exhaustive-test-compile-issue

This commit is contained in:
David Norman 2019-07-12 12:08:01 +01:00
commit 20f510ceae
604 changed files with 18795 additions and 8108 deletions

View File

@ -40,6 +40,7 @@ build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0

View File

@ -1,3 +1,274 @@
# Release 1.14.0
## Major Features and Improvements
* This is the first 1.x release containing the compat.v2 module. This module
is required to allow libraries to publish code which works in both 1.x and
2.x. After this release, no backwards incompatible changes are allowed in
the 2.0 Python API.
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector architecture.
To disable them, build with --define=tensorflow_mkldnn_contraction_kernel=0.
## Behavioral changes
* Set the default loss reduction to `AUTO` to improve the reliability of loss
scaling with distribution strategy and custom training loops. `AUTO`
indicates that the reduction option will be determined by the usage context.
For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When used in a
distribution strategy scope, outside of built-in training loops such as
`tf.keras` `compile` and `fit`, we expect the reduction value to be 'None' or
'SUM'; using other values will raise an error. (A sketch of setting the
reduction explicitly follows this list.)
* Losses passed to the `compile` API (strings and v1 losses) that are not
instances of the v2 `Loss` class are now wrapped in a `LossWrapper` class, so
all losses now use `SUM_OVER_BATCH_SIZE` reduction by default.
* Disable `run_eagerly` and distribution strategy if there are symbolic
tensors added to the model using `add_metric` or `add_loss`.
* `tf.linspace(start, stop, num)` now always uses `stop` as the last value
(for `num > 1`).
* `ResourceVariable` and `Variable` no longer accept `constraint` in the
constructor, nor expose it as a @property.
* The behavior of `tf.gather` is now correct when `axis=None` and
`batch_dims<0`.
* Only create a GCS directory object if the object does not already exist.
* In `map_vectorization` optimization, reduce the degree of parallelism in the
vectorized map node.
* Bug fix: loss and gradients should now more reliably be correctly scaled
w.r.t. the global batch size when using a tf.distribute.Strategy.
* Updated the cosine similarity loss: removed the negation sign from cosine
similarity.
* `DType` is no longer convertible to an `int`. Use `dtype.as_datatype_enum`
instead of `int(dtype)` to get the same result.
* Changed default for gradient accumulation for TPU embeddings to true.
* Callbacks now log values in eager mode when a deferred build model is used.
* Transitive dependencies on :pooling_ops were removed. Some users may need to
add explicit dependencies on :pooling_ops if they reference the operators
from that library.
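
As noted in the loss-reduction item above, here is a minimal sketch of setting
the reduction explicitly in a custom training loop under a distribution
strategy; the batch size and loss choice are illustrative assumptions:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64  # illustrative value

with strategy.scope():
  # Outside built-in `compile`/`fit`, pass an explicit reduction (NONE or SUM);
  # other values raise an error inside a distribution strategy scope.
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, logits):
  per_example_loss = loss_obj(labels, logits)
  # Scale w.r.t. the global batch size, not the per-replica batch size.
  return tf.nn.compute_average_loss(
      per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
```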
## Bug Fixes and Other Changes
* Documentation
* Deprecations and Symbol renames.
* Remove unused StringViewVariantWrapper
* Delete unused Fingerprint64Map op registration
* SignatureDef util functions have been deprecated.
* Renamed tf.image functions to remove duplicate "image" where it is
redundant.
* tf.keras.experimental.export renamed to
tf.keras.experimental.export_saved_model
* Standardize the LayerNormalization API by replacing the args `norm_axis`
and `params_axis` with `axis`.
* `Tensor::UnsafeCopyFromInternal` deprecated in favor of
`Tensor::BitcastFrom`.
* Keras & Python API
* Add v2 module aliases for:
* tf.initializers => tf.keras.initializers
* tf.losses => tf.keras.losses & tf.metrics => tf.keras.metrics
* tf.optimizers => tf.keras.optimizers
* Add `tf.keras.layers.AbstractRNNCell` as the preferred implementation of an
RNN cell for TF v2. Users can use it to implement RNN cells with custom
behavior.
* Added a `clear_losses` API to clear losses at the end of a forward pass in
a custom training loop in eager mode.
* Add support for passing list of lists to the `metrics` param in Keras
`compile`.
* Added top-k support to the precision and recall Keras metrics.
* Added public APIs for the `cumsum` and `cumprod` Keras backend functions.
* Fix: `model.add_loss(symbolic_tensor)` should work in ambient eager mode.
* Add name argument to tf.string_split and tf.strings_split
* Minor change to SavedModels exported from Keras using
tf.keras.experimental.export. (SignatureDef key for evaluation mode is
now "eval" instead of "test"). This will be reverted back to "test" in
the near future.
* Updated binary cross-entropy logic in Keras when the input is
probabilities: instead of converting probabilities to logits, we use the
cross-entropy formula for probabilities directly (see the sketch after this
group of Keras notes).
* Raw TensorFlow functions can now be used in conjunction with the Keras
Functional API during model creation. This obviates the need for users
to create Lambda layers in most cases when using the Functional API.
Like Lambda layers, TensorFlow functions that result in Variable
creation or assign ops are not supported.
* Keras training and validation curves are shown on the same plot.
* Introduce `dynamic` constructor argument in Layer and Model, which
should be set to True when using imperative control flow in the `call`
method.
* Removed `dtype` from the constructor of initializers and `partition_info`
from `call`.
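
Relating to the binary cross-entropy note above, a small sketch with toy
values of applying the cross-entropy formula to probabilities directly instead
of converting them to logits; the epsilon clipping constant is an assumption
for numerical stability:

```python
import tensorflow as tf

y_true = tf.constant([1.0, 0.0, 1.0])
y_prob = tf.constant([0.9, 0.2, 0.6])  # already probabilities, not logits
eps = 1e-7                             # assumed clipping constant

p = tf.clip_by_value(y_prob, eps, 1.0 - eps)
# Cross-entropy computed on probabilities directly:
bce = -tf.reduce_mean(y_true * tf.math.log(p) +
                      (1.0 - y_true) * tf.math.log(1.0 - p))

# Should agree, up to epsilon handling, with the built-in loss:
keras_bce = tf.keras.losses.binary_crossentropy(y_true, y_prob)
```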
* New ops and improved op functionality
* Add OpKernels for some stateless maps
* Add v2 APIs for AUCCurve and AUCSummationMethod
enums. #tf-metrics-convergence
* Add tf.math.nextafter op.
* Add CompositeTensor base class.
* Add tf.linalg.tridiagonal_solve op.
* Add opkernel templates for common table operations.
* Added support for TFLite in TensorFlow 2.0.
* Adds summary trace API for collecting graph and profile information.
* Add batch_dims argument to tf.gather.
* Add support for `add_metric` in the graph function mode.
* Add C++ Gradient for BatchMatMulV2.
* Added tf.random.binomial
* Added gradient for SparseToDense op.
* Add legacy string flat hash map op kernels
* Add a ragged size op and register it to the op dispatcher
* Add broadcasting support to tf.matmul.
* Add ellipsis (...) support for tf.einsum() (see the sketch after this group
of op notes).
* Added LinearOperator.adjoint and LinearOperator.H (alias).
* Added GPU implementation of tf.linalg.tridiagonal_solve.
* Added strings.byte_split
* Add RaggedTensor.placeholder()
* Add a new "result_type" parameter to tf.strings.split
* `add_update` can now be passed a zero-arg callable in order to support
turning off the update when setting `trainable=False` on a Layer of a
Model compiled with `run_eagerly=True`.
* Add variant wrapper for absl::string_view
* Add expand_composites argument to all nest.* methods.
* Add pfor converter for Squeeze.
* Bug fix for tf.tile gradient
* Expose CriticalSection in core as tf.CriticalSection.
* Update Fingerprint64Map to use aliases
* ResourceVariable support for gather_nd.
* ResourceVariable's gather op supports batch dimensions.
* Variadic reduce is supported on CPU
* Extend tf.function with basic support for CompositeTensors arguments
(such as SparseTensor and RaggedTensor).
* Add templates and interfaces for creating lookup tables
* Post-training quantization tool supports quantizing weights shared by
multiple operations. The models made with versions of this tool will use
INT8 types for weights and will only be executable by interpreters from this
version onwards.
* Malformed GIF images could result in an out-of-bounds access in the frame's
color palette. This has now been fixed.
* image.resize now considers proper pixel centers and has new kernels
(incl. anti-aliasing).
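
As referenced in the `tf.matmul`/`tf.einsum` items above, a short sketch of
the new broadcasting and ellipsis support; the shapes are illustrative:

```python
import tensorflow as tf

a = tf.random.normal([2, 3, 4])
b = tf.random.normal([2, 4, 5])

# Ellipsis (...) stands for any leading batch dimensions.
c = tf.einsum('...ij,...jk->...ik', a, b)      # shape [2, 3, 5]

# tf.matmul now broadcasts batch dimensions as well.
d = tf.matmul(a, tf.random.normal([1, 4, 5]))  # shape [2, 3, 5]
```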
* Performance
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector
architecture. To disable them, build with
--define=tensorflow_mkldnn_contraction_kernel=0.
* Support for multi-host ncclAllReduce in Distribution Strategy.
* Expose a flag that allows the number of threads to vary across Python
benchmarks.
* TensorFlow 2.0 Development
* Add v2 sparse categorical crossentropy metric.
* Allow non-Tensors through v2 losses.
* Add `UnifiedGRU` as the new GRU implementation for TF 2.0. Change the
default recurrent activation function for GRU from 'hard_sigmoid' to
'sigmoid', and `reset_after` to True in 2.0. Historically, the recurrent
activation was 'hard_sigmoid' since it is faster than 'sigmoid'. With the
new unified backend between CPU and GPU modes, and since the CuDNN kernel
uses sigmoid, we change the default for CPU mode to sigmoid as well. With
that, the default GRU is compatible with both the CPU and GPU kernels. This
lets users with a GPU use the CuDNN kernel by default and get a 10x
performance boost in training. Note that this is a checkpoint-breaking
change; users who want to use their 1.x pre-trained checkpoints should
construct the layer with
GRU(recurrent_activation='hard_sigmoid', reset_after=False) to fall back to
the 1.x behavior (see the sketch after this group of notes).
* TF 2.0: Update metric names to always reflect what the user has passed to
`compile`. This affects the following cases: 1) when the name is given as
'accuracy'/'crossentropy'; 2) when an aliased function name is used, e.g.
'mse'; 3) the `weighted` prefix is removed from weighted metric names.
* Begin adding Go wrapper for C Eager API
* image.resize in 2.0 now supports gradients for the new resize kernels.
* removed tf.string_split from v2 API
* Expose tf.contrib.proto.* ops in tf.io (they will exist in TF2)
* "Updates the TFLiteConverter API in 2.0. Changes from_concrete_function
to from_concrete_functions."
* Enable tf.distribute.experimental.MultiWorkerMirroredStrategy working in
eager mode.
* Support both binary and -1/1 label input in v2 hinge and squared hinge
losses.
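
Following up on the GRU note above, a minimal sketch (the layer width is
illustrative) of the new 2.0 default versus the constructor arguments given in
that note for falling back to 1.x-compatible behavior when restoring an old
checkpoint:

```python
import tensorflow as tf

# TF 2.0 default: CuDNN-compatible configuration
# (recurrent_activation='sigmoid', reset_after=True).
gru = tf.keras.layers.GRU(64)

# To reuse a 1.x pre-trained checkpoint, construct the layer with the 1.x
# defaults, as described in the release note above.
legacy_gru = tf.keras.layers.GRU(
    64, recurrent_activation='hard_sigmoid', reset_after=False)
```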
* TensorFlow Lite
* "Adds support for tflite_convert in 2.0."
* "Remove lite.OpHint, lite.experimental, and lite.constant from 2.0 API."
* tf.contrib
* Added a Neural Turing Machine implementation as described in
https://arxiv.org/abs/1807.08518.
* Remove tf.contrib.timeseries dependency on TF distributions.
* tf.data
* Add `num_parallel_reads` and support for passing in a Dataset containing
filenames into TextLineDataset and FixedLengthRecordDataset (see the sketch
after these tf.data notes).
* Going forward we operate in TF 2.0; this change is part of the effort to
slowly convert XYZDataset to the DatasetV2 type, which is the official
version to be used in TF 2.0. It was motivated by a compatibility issue found
when moving contrib.bigtable to tensorflow_io: _BigtableXYZDataset (of type
DatasetV2) does not implement the _as_variant_tensor() method of DatasetV1.
Converting to DatasetV2 removes the overhead of maintaining V1 while we move
to TF 2.0.
* Add dataset ops to the graph (or create kernels in Eager execution) during
Python Dataset object creation instead of doing it at Iterator creation
time.
* Add support for TensorArrays to tf.data Dataset.
* Switching tf.data functions to use `defun`, providing an escape hatch to
continue using the legacy `Defun`.
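
As mentioned in the tf.data items above, a small sketch of reading text files
in parallel by passing a Dataset of filenames into `TextLineDataset`; the glob
pattern and degree of parallelism are illustrative assumptions:

```python
import tensorflow as tf

# Hypothetical input location.
files = tf.data.Dataset.list_files("/tmp/logs/*.txt")

# A Dataset of filenames can now be passed directly, together with
# num_parallel_reads to interleave reads across files.
lines = tf.data.TextLineDataset(files, num_parallel_reads=4)
```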
* Toolchains
* CUDNN_INSTALL_PATH, TENSORRT_INSTALL_PATH, NCCL_INSTALL_PATH, and
NCCL_HDR_PATH are deprecated. Use TF_CUDA_PATHS instead, which supports a
comma-separated list of base paths that are searched to find CUDA libraries
and headers.
* TF code now resides in `tensorflow_core` and `tensorflow` is just a
virtual pip package. No code changes are needed for projects using
TensorFlow; the change is transparent.
* XLA
* XLA HLO graphs can now be inspected with the interactive_graphviz tool.
* Estimator
* Use `tf.compat.v1.estimator.inputs` instead of `tf.estimator.inputs`.
* Replace contrib references with `tf.estimator.experimental.*` for APIs in
early_stopping.py.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
1e100, 4d55397500, a6802739, abenmao, Adam Weiss, Ag Ramesh, Alan Du, Albin Joy,
Alex, Aman Patel, Amit, Amit Kumar Jaiswal, Amit Srivastava, Andreas Eberle,
Andy Craze, Anthony Platanios, Armen Poghosov, armenpoghosov, arp95, Arpit Shah,
Ashwin Ramaswami, Aurelien Geron, AuréLien Geron, aweers, awesomealex1, Ayush
Agrawal, Ben Barsdell, Bharat Raghunathan, Bhavani Subramanian, blairhan,
BléNesi Attila, Brandon Carter, candy.dc, Chao Liu, chenchc, chie8842, Christian
Hansen, Christian Sigg, Clayne Robison, crafet, csukuangfj, ctiijima, Dan
Jarvis, Dan Lazewatsky, Daniel Ingram, Daniel Salvadori, Dave Airlie, David
Norman, Dayananda V, Dayananda-V, delock, Denis Khalikov, Deven Desai, Dheeraj
Rajaram Reddy, dmitrievanthony, Donovan Ong, Drew Szurko, Duncan Riach, Dustin
Neighly, Edward Forgacs, EFanZh, Fei Hu, Felix Lemke, Filip Matzner, fo40225,
frreiss, Gautam, gehring, Geoffrey Irving, Grzegorz George Pawelczak, Grzegorz
Pawelczak, Gyoung-Yoon Ryoo, HanGuo97, Hanton Yang, Hari Shankar, hehongliang,
Heungsub Lee, Hoeseong Kim, I-Hong Jhuo, Ilango R, Innovimax, Irene Dea, Jacky
Ko, Jakub Lipinski, Jason Zaman, jcf94, Jeffrey Poznanovic, Jens Elofsson,
Jeroen BéDorf, Jia Qingtong, Jiankang, Joe Q, Joe Quadrino, Joeran Beel, Jonas
Rauber, Jonathan, Jonathan Kyl, Joppe Geluykens, Joseph Friedman, jtressle, jwu,
K Yasaswi Sri Chandra Gandhi, K. Hodges, Kaixi Hou, Karl Lessard, Karl
Weinmeister, Karthik Muthuraman, Kashif Rasul, KDR, Keno Fischer, Kevin Mader,
kjopek, Koan-Sin Tan, kouml, ktaebum, Lakshay Tokas, Laurent Le Brun, Letian
Kang, Li, Guizi, Loo Rong Jie, Lucas Hendren, Lukas Geiger, Luke Han, luxupu,
Ma, Guokai, Mahmoud Abuzaina, Mandar Deshpande, manhyuk, Marco Gaido, Marek
Drozdowski, Mark Collier, Mark Ryan, mars20, Mateusz Chudyk, Matt Conley,
MattConley, mbhuiyan, mdfaijul, Melissa Grueter, Michael KäUfl, MickaëL
Schoentgen, Miguel Morin, Mihail Salnikov, Mike Arpaia, Mike Holcomb, monklof,
Moses Marin, Mshr-H, nammbash, Natalia Gimelshein, Nayana-Ibm, neargye, Neeraj
Pradhan, Nehal J Wani, Nick, Niels Ole Salscheider, Niranjan Hasabnis, nlewycky,
Nuka-137, Nutti, olicht, P Sudeepam, Palmer Lao, Pan Daoxin, Pariksheet Pinjari,
Pavel Samolysov, PENGWA, Pooya Davoodi, R S Nikhil Krishna, Rohit Gupta, Roman
Soldatow, rthadur, Ruizhe, Ryan Jiang, Samantha Andow, Sami Kama, Sana-Damani,
Saurabh Deoras, sdamani, seanshpark, Sebastien Iooss, Serv-Inc, Shahzad Lone,
Shashank Gupta, Shashi, shashvat, shashvatshahi1998, Siju, Siju Samuel,
Snease-Abq, Spencer Schaber, sremedios, srinivasan.narayanamoorthy, Steve Lang,
Steve Nesae, Sumesh Udayakumaran, Supriya Rao, Taylor Jakobson, Taylor Thornton,
Ted Chang, ThisIsPIRI, Thomas Deegan, Thomas Hagebols, tianyapiaozi, Tim Zaman,
tomguluson92, Tongxuan Liu, TungJerry, v1incent, Vagif, vcarpani, Vikram Tiwari,
Vishwak Srinivasan, Vitor-Alves, wangsiyu, wateryzephyr, WeberXie, WeijieSun,
Wen-Heng (Jack) Chung, wenxizhu, Will Battel, William D. Irons, wyzhao, Xin,
Yasuhiro Matsumoto, ymodak, Yong Tang, Younes Khoudli, Yuan Lin, Yves-Noel
Weweler, Zantares, zjjott, 卜居, 王振华 (Wang Zhenhua), 黄鑫
# Release 1.12.3
## Bug Fixes and Other Changes
* Updates `png_archive` dependency to 1.6.37 to not be affected by
CVE-2019-7317, CVE-2018-13785, and CVE-2018-14048.
* Updates `sqlite` dependency to 3.28.0 to not be affected by CVE-2018-20506,
CVE-2018-20346, and CVE-2018-20505.
# Release 1.12.2
## Bug Fixes and Other Changes

View File

@ -41,22 +41,30 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Make sure directory containing top level submodules is in # Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works. # the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that. # We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable _API_MODULE = _sys.modules[__name__].bitwise
_current_module = _sys.modules[__name__]
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, '__path__'): if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir] __path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__: elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir) __path__.append(_tf_api_dir)
# Hook external TensorFlow modules. # Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
try: try:
from tensorboard.summary._tf import summary from tensorboard.summary._tf import summary
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__) [_module_util.get_parent_dir(summary)] + _current_module.__path__)
setattr(_current_module, "summary", summary)
except ImportError: except ImportError:
_logging.warning( _logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.") "Limited tf.summary API due to missing TensorBoard installation.")
@ -65,6 +73,7 @@ try:
from tensorflow_estimator.python.estimator.api._v2 import estimator from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__) [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError: except ImportError:
pass pass
@ -72,6 +81,7 @@ try:
from tensorflow.python.keras.api._v2 import keras from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__) [_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError: except ImportError:
pass pass
@ -122,25 +132,17 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable # pylint: disable=undefined-variable
try: try:
del python del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError: except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass pass
# Similarly for compiler. Do it separately to make sure we do this even if the try:
# others don't exist. del core
except NameError:
pass
try: try:
del compiler del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError: except NameError:
pass pass
# pylint: enable=undefined-variable
# Add module aliases # Add module aliases
if hasattr(_current_module, 'keras'): if hasattr(_current_module, 'keras'):
@ -148,6 +150,8 @@ if hasattr(_current_module, 'keras'):
metrics = keras.metrics metrics = keras.metrics
optimizers = keras.optimizers optimizers = keras.optimizers
initializers = keras.initializers initializers = keras.initializers
setattr(_current_module, "losses", losses)
compat.v2.compat.v1 = compat.v1 setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable # pylint: enable=undefined-variable

View File

@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Make sure directory containing top level submodules is in # Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works. # the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that. # We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable _API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__] _current_module = _sys.modules[__name__]
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'): if not hasattr(_current_module, '__path__'):
@ -46,6 +48,7 @@ try:
from tensorflow_estimator.python.estimator.api._v1 import estimator from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__) [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError: except ImportError:
pass pass
@ -53,6 +56,7 @@ try:
from tensorflow.python.keras.api._v1 import keras from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__) [_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError: except ImportError:
pass pass
@ -77,9 +81,8 @@ if '__all__' in vars():
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
# The 'app' module will be imported as part of the placeholder section above. # The 'app' module will be imported as part of the placeholder section above.
app.flags = flags # pylint: disable=undefined-variable _current_module.app.flags = flags # pylint: disable=undefined-variable
if '__all__' in vars(): setattr(_current_module, "flags", flags)
vars()['__all__'].append('flags')
# Load all plugin libraries from site-packages/tensorflow-plugins if we are # Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip. # running under pip.
@ -122,25 +125,16 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable # pylint: disable=undefined-variable
try: try:
del python del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError: except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass pass
# Similarly for compiler. Do it separately to make sure we do this even if the try:
# others don't exist. del core
except NameError:
pass
try: try:
del compiler del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError: except NameError:
pass pass
compat.v2.compat.v1 = compat.v1 _current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable # pylint: enable=undefined-variable

View File

@ -64,6 +64,7 @@ tf_cuda_library(
}) + [ }) + [
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:remote_mgr",
"//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_channel",

View File

@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/rendezvous.h"
@ -260,12 +261,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr; auto* device_mgr = grpc_server->worker_env()->device_mgr;
auto remote_mgr =
absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true);
return ctx->context->InitializeRemoteMaster( return ctx->context->InitializeRemoteMaster(
std::move(server), grpc_server->worker_env(), worker_session, std::move(server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr), std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, remote_workers, context_id, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr.get()); worker_session->cluster_flr.get(), std::move(remote_mgr));
#undef LOG_AND_RETURN_IF_ERROR #undef LOG_AND_RETURN_IF_ERROR
} }
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM

View File

@ -890,9 +890,19 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
// Stop the tape from recording // Stop the tape from recording
pop_backward_tape.release()(); pop_backward_tape.release()();
if (grad.size() != in_grads.size()) {
return tensorflow::errors::Internal("Wrong number of gradients returned.");
}
std::vector<int64> targets; std::vector<int64> targets;
std::vector<Gradient*> used_in_grads;
// We may end up with slightly fewer elements than we reserve, but grad.size()
// should be a reasonably tight upper bound.
targets.reserve(grad.size());
used_in_grads.reserve(grad.size());
std::unordered_map<int64, TapeTensor> sources_that_are_targets; std::unordered_map<int64, TapeTensor> sources_that_are_targets;
for (Gradient* grad_tensor : grad) { for (int grad_index = 0; grad_index < grad.size(); ++grad_index) {
Gradient* grad_tensor = grad[grad_index];
if (grad_tensor != nullptr) { if (grad_tensor != nullptr) {
int64 tensor_id = vspace_.TensorId(grad_tensor); int64 tensor_id = vspace_.TensorId(grad_tensor);
targets.push_back(tensor_id); targets.push_back(tensor_id);
@ -900,23 +910,18 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
sources_that_are_targets.emplace( sources_that_are_targets.emplace(
tensor_id, vspace_.TapeTensorFromGradient(grad_tensor)); tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
} }
} Gradient* in_grad = in_grads[grad_index];
} if (in_grad != nullptr) {
if (targets.size() > in_grads.size()) {
return tensorflow::errors::Internal("Too many gradients returned.");
}
for (int target_index = 0; target_index < targets.size(); ++target_index) {
Gradient* in_grad = in_grads[target_index];
Gradient* grad_tensor = grad[target_index];
if (grad_tensor != nullptr && in_grad != nullptr) {
// ComputeGradient steals a reference // ComputeGradient steals a reference
vspace_.MarkAsResult(in_grad); vspace_.MarkAsResult(in_grad);
} }
used_in_grads.push_back(in_grad);
}
} }
return tape->ComputeGradient(vspace_, targets, sources, return tape->ComputeGradient(vspace_, targets, sources,
sources_that_are_targets, in_grads, out_grads); sources_that_are_targets, used_in_grads,
out_grads);
} }
template <typename Gradient, typename BackwardFunction, typename TapeTensor> template <typename Gradient, typename BackwardFunction, typename TapeTensor>

View File

@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Hook external TensorFlow modules. # Hook external TensorFlow modules.
_current_module = _sys.modules[__name__] _current_module = _sys.modules[__name__]
try: try:
from tensorboard.summary._tf import summary from tensorboard.summary._tf import summary
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__) [_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError: except ImportError:
_logging.warning( _logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard " "Limited tf.compat.v2.summary API due to missing TensorBoard "
@ -43,6 +47,7 @@ try:
from tensorflow_estimator.python.estimator.api._v2 import estimator from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__) [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError: except ImportError:
pass pass
@ -50,6 +55,7 @@ try:
from tensorflow.python.keras.api._v2 import keras from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__) [_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError: except ImportError:
pass pass
@ -61,11 +67,15 @@ except ImportError:
# #
# This make this one symbol available directly. # This make this one symbol available directly.
from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top
setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
# Add module aliases # Add module aliases
_current_module = _sys.modules[__name__]
if hasattr(_current_module, 'keras'): if hasattr(_current_module, 'keras'):
losses = keras.losses losses = keras.losses
metrics = keras.metrics metrics = keras.metrics
optimizers = keras.optimizers optimizers = keras.optimizers
initializers = keras.initializers initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)

View File

@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER # API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Hook external TensorFlow modules. # Hook external TensorFlow modules.
_current_module = _sys.modules[__name__] _current_module = _sys.modules[__name__]
try: try:
from tensorflow_estimator.python.estimator.api._v1 import estimator from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__) [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError: except ImportError:
pass pass
@ -40,9 +43,11 @@ try:
from tensorflow.python.keras.api._v1 import keras from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = ( _current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__) [_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError: except ImportError:
pass pass
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable _current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)

View File

@ -22,6 +22,7 @@ load(
"tf_cc_test", "tf_cc_test",
"tf_copts", "tf_copts",
) )
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
def tf_library( def tf_library(
name, name,
@ -180,13 +181,7 @@ def tf_library(
# `find` on such an object. # `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1 need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
# Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel flags = tfcompile_extra_flags() + flags
# build --cpu=haswell). We put it at the beginning of the flags list so
# that tfcompile_flags can override if if desired.
flags = select({
"//tools/target_cpu:haswell": "--target_cpu=haswell ",
"//conditions:default": "",
}) + flags
if enable_xla_hlo_profiling: if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile" profiling_flag = "--xla_hlo_profile"

View File

@ -1087,8 +1087,6 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody)); FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
InlineFunctionBodyOptions inline_opts; InlineFunctionBodyOptions inline_opts;
inline_opts.override_device = false;
TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node, TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
fbody.get(), inline_opts)); fbody.get(), inline_opts));
} }

View File

@ -183,6 +183,9 @@ class XlaAssignVariableOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \ Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
data::AnonymousIteratorHandleOp); \ data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \
data::DeleteIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
data::IteratorGetNextOp); \ data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \

View File

@ -462,9 +462,16 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
filegroup(
name = "tf_tfl_translate_main",
srcs = [
"tf_tfl_translate.cc",
],
)
tf_cc_binary( tf_cc_binary(
name = "tf_tfl_translate", name = "tf_tfl_translate",
srcs = ["tf_tfl_translate.cc"], srcs = [":tf_tfl_translate_main"],
deps = [ deps = [
":flatbuffer_translate_lib", ":flatbuffer_translate_lib",
":tensorflow_lite", ":tensorflow_lite",

View File

@ -26,11 +26,11 @@ namespace tflite {
// Error reporter that reports errors via the module's emitError. // Error reporter that reports errors via the module's emitError.
class EmitErrorReporter : public ErrorReporter { class EmitErrorReporter : public ErrorReporter {
public: public:
explicit EmitErrorReporter(mlir::Module module) : module_(module) {} explicit EmitErrorReporter(mlir::ModuleOp module) : module_(module) {}
int Report(const char* format, va_list args) override; int Report(const char* format, va_list args) override;
private: private:
mlir::Module module_; mlir::ModuleOp module_;
}; };
} // namespace tflite } // namespace tflite

View File

@ -42,6 +42,13 @@ static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
.Case("VALID", tflite::Padding_VALID); .Case("VALID", tflite::Padding_VALID);
} }
static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
return llvm::StringSwitch<tflite::MirrorPadMode>(str)
.Case("REFLECT", tflite::MirrorPadMode_REFLECT)
.Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
}
static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter( static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
mlir::Type type, flatbuffers::FlatBufferBuilder* builder) { mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
switch (type.getKind()) { switch (type.getKind()) {

View File

@ -77,9 +77,9 @@ using llvm::Twine;
using mlir::Block; using mlir::Block;
using mlir::Dialect; using mlir::Dialect;
using mlir::ElementsAttr; using mlir::ElementsAttr;
using mlir::Function; using mlir::FuncOp;
using mlir::MLIRContext; using mlir::MLIRContext;
using mlir::Module; using mlir::ModuleOp;
using mlir::NoneType; using mlir::NoneType;
using mlir::openOutputFile; using mlir::openOutputFile;
using mlir::Operation; using mlir::Operation;
@ -167,6 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_STRING; return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::COMPLEX64: case mlir::TF::TensorFlowTypes::COMPLEX64:
return tflite::TensorType_COMPLEX64; return tflite::TensorType_COMPLEX64;
case mlir::TF::TensorFlowTypes::UINT8:
return tflite::TensorType_UINT8;
case mlir::StandardTypes::Integer: { case mlir::StandardTypes::Integer: {
const auto& itype = type.cast<mlir::IntegerType>(); const auto& itype = type.cast<mlir::IntegerType>();
switch (itype.getWidth()) { switch (itype.getWidth()) {
@ -243,18 +245,18 @@ static bool HasValidTFLiteType(Value* value, T& error_handler) {
// TODO(hinsu): Now that translation is done by making a single pass over the // TODO(hinsu): Now that translation is done by making a single pass over the
// MLIR module, consider inlining these validation checks at the place where // MLIR module, consider inlining these validation checks at the place where
// these invariants are assumed instead of checking upfront. // these invariants are assumed instead of checking upfront.
static bool IsValidTFLiteMlirModule(Module module) { static bool IsValidTFLiteMlirModule(ModuleOp module) {
MLIRContext* context = module.getContext(); MLIRContext* context = module.getContext();
// Verify that module has a function named main. // Verify that module has a function named main.
Function main_fn = module.getNamedFunction("main"); FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
if (!main_fn) { if (!main_fn) {
return emitError(UnknownLoc::get(context), return emitError(UnknownLoc::get(context),
"should have a function named 'main'"), "should have a function named 'main'"),
false; false;
} }
for (auto fn : module.getOps<Function>()) { for (auto fn : module.getOps<FuncOp>()) {
if (fn.getBlocks().size() != 1) { if (fn.getBlocks().size() != 1) {
return fn.emitError("should have exactly one basic block"), false; return fn.emitError("should have exactly one basic block"), false;
} }
@ -323,14 +325,14 @@ class Translator {
// Translates the given MLIR module into TFLite FlatBuffer format and returns // Translates the given MLIR module into TFLite FlatBuffer format and returns
// the serialized output. Returns llvm::None on unsupported, invalid inputs or // the serialized output. Returns llvm::None on unsupported, invalid inputs or
// internal error. // internal error.
static Optional<std::string> Translate(Module module, static Optional<std::string> Translate(ModuleOp module,
bool emit_builtin_tflite_ops, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_select_tf_ops,
bool emit_custom_ops); bool emit_custom_ops);
private: private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(Module module, bool emit_builtin_tflite_ops, explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops) bool emit_select_tf_ops, bool emit_custom_ops)
: module_(module), builder_(kInitialBufferSize) { : module_(module), builder_(kInitialBufferSize) {
// The first buffer must be empty according to the schema definition. // The first buffer must be empty according to the schema definition.
@ -391,11 +393,11 @@ class Translator {
Operation* inst, const std::vector<int32_t>& operands, Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results); const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(Function fn); Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Uses the tf.entry_function attribute (if set) to initialize the op to name // Uses the tf.entry_function attribute (if set) to initialize the op to name
// mapping. // mapping.
void InitializeNamesFromAttribute(Function fn); void InitializeNamesFromAttribute(FuncOp fn);
// Returns a unique name for `op`. // Returns a unique name for `op`.
std::string UniqueName(mlir::Operation* op); std::string UniqueName(mlir::Operation* op);
@ -403,7 +405,7 @@ class Translator {
// Returns a unique name starting with a given prefix. // Returns a unique name starting with a given prefix.
std::string UniqueName(llvm::StringRef prefix); std::string UniqueName(llvm::StringRef prefix);
Module module_; ModuleOp module_;
flatbuffers::FlatBufferBuilder builder_; flatbuffers::FlatBufferBuilder builder_;
BufferOffset<tflite::Buffer> empty_buffer_; BufferOffset<tflite::Buffer> empty_buffer_;
@ -449,10 +451,16 @@ std::string Translator::GetName(Operation* inst) {
} }
std::string Translator::UniqueName(llvm::StringRef prefix) { std::string Translator::UniqueName(llvm::StringRef prefix) {
// Keep incrementing the counter until we find a unique name.
std::string name = prefix; std::string name = prefix;
auto& val = name_to_count_[name]; int64_t& prefix_count = name_to_count_[name];
if (val) name = (prefix + llvm::Twine(val)).str(); int64_t val = prefix_count;
++val; while (val != 0) {
name = (prefix + llvm::Twine(prefix_count)).str();
++prefix_count;
val = name_to_count_[name];
}
name_to_count_[name] = 1;
return name; return name;
} }
@ -781,7 +789,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
llvm::None; llvm::None;
} }
void Translator::InitializeNamesFromAttribute(Function fn) { void Translator::InitializeNamesFromAttribute(FuncOp fn) {
auto dict_attr = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function"); auto dict_attr = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
if (!dict_attr) return; if (!dict_attr) return;
@ -794,8 +802,10 @@ void Translator::InitializeNamesFromAttribute(Function fn) {
fn.emitWarning() << "invalid entry function specification"; fn.emitWarning() << "invalid entry function specification";
return; return;
} }
for (auto it : llvm::enumerate(fn.getArguments())) for (auto it : llvm::enumerate(fn.getArguments())) {
op_to_name_[*it.value()->user_begin()] = input_names[it.index()]; op_to_name_[*it.value()->user_begin()] = input_names[it.index()];
++name_to_count_[input_names[it.index()].str()];
}
} }
if (auto str = dict_attr.get("outputs").dyn_cast<mlir::StringAttr>()) { if (auto str = dict_attr.get("outputs").dyn_cast<mlir::StringAttr>()) {
@ -813,18 +823,19 @@ void Translator::InitializeNamesFromAttribute(Function fn) {
// ensure the name that will be assigned to the buffer is the same, or // ensure the name that will be assigned to the buffer is the same, or
// insert an op so that we can have a buffer named such. This cannot // insert an op so that we can have a buffer named such. This cannot
// currently happen due to pseudo_input nodes. // currently happen due to pseudo_input nodes.
if (auto op = it.value()->getDefiningOp()) if (auto op = it.value()->getDefiningOp()) {
op_to_name_[op] = output_names[it.index()]; op_to_name_[op] = output_names[it.index()];
else name_to_count_[output_names[it.index()].str()] = 1;
} else {
fn.emitWarning() << "output is not due to an op and '" fn.emitWarning() << "output is not due to an op and '"
<< output_names[it.index()] << output_names[it.index()]
<< "' may not be a named output"; << "' may not be a named output";
} }
} }
}
} }
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph( Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Function fn) {
InitializeNamesFromAttribute(fn); InitializeNamesFromAttribute(fn);
std::vector<BufferOffset<tflite::Tensor>> tensors; std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value*, int> tensor_index_map; llvm::DenseMap<Value*, int> tensor_index_map;
@ -927,7 +938,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
/*name=*/builder_.CreateString(fn.getName().str())); /*name=*/builder_.CreateString(fn.getName().str()));
} }
Optional<std::string> Translator::Translate(Module module, Optional<std::string> Translator::Translate(ModuleOp module,
bool emit_builtin_tflite_ops, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_select_tf_ops,
bool emit_custom_ops) { bool emit_custom_ops) {
@ -941,14 +952,14 @@ Optional<std::string> Translator::TranslateInternal() {
// Create a list of functions in the module with main function being the // Create a list of functions in the module with main function being the
// first function in the list. This is required as the first subgraph in the // first function in the list. This is required as the first subgraph in the
// model is entry point for the model. // model is entry point for the model.
std::vector<Function> functions; std::vector<FuncOp> functions;
functions.reserve(std::distance(module_.begin(), module_.end())); functions.reserve(std::distance(module_.begin(), module_.end()));
int subgraph_idx = 0; int subgraph_idx = 0;
Function main_fn = module_.getNamedFunction("main"); FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++; subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
functions.push_back(main_fn); functions.push_back(main_fn);
for (auto fn : module_.getOps<Function>()) { for (auto fn : module_.getOps<FuncOp>()) {
if (fn == main_fn) continue; if (fn == main_fn) continue;
subgraph_index_map_[fn.getName().str()] = subgraph_idx++; subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
@ -992,7 +1003,7 @@ Optional<std::string> Translator::TranslateInternal() {
// * Ops with variable tensors // * Ops with variable tensors
// //
bool tflite::MlirToFlatBufferTranslateFunction( bool tflite::MlirToFlatBufferTranslateFunction(
Module module, std::string* serialized_flatbuffer, ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops) { bool emit_custom_ops) {
auto maybe_translated = Translator::Translate( auto maybe_translated = Translator::Translate(
@ -1003,7 +1014,7 @@ bool tflite::MlirToFlatBufferTranslateFunction(
} }
static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction(
Module module, llvm::StringRef filename) { ModuleOp module, llvm::StringRef filename) {
std::string serialized_flatbuffer; std::string serialized_flatbuffer;
if (tflite::MlirToFlatBufferTranslateFunction( if (tflite::MlirToFlatBufferTranslateFunction(
module, &serialized_flatbuffer, emit_builtin_tflite_ops, module, &serialized_flatbuffer, emit_builtin_tflite_ops,

View File

@ -32,7 +32,7 @@ namespace tflite {
// Translates the given MLIR `module` into a FlatBuffer and stores the // Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string. // serialized flatbuffer into the string.
bool MlirToFlatBufferTranslateFunction(mlir::Module module, bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
std::string *serialized_flatbuffer, std::string *serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_select_tf_ops,

View File

@ -50,6 +50,13 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
"TFLite string type">, "TFLite string type">,
BuildableType<"getType<mlir::TF::StringType>()">; BuildableType<"getType<mlir::TF::StringType>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
//===----------------------------------------------------------------------===//
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
"TFLite uint8 type">,
BuildableType<"getType<mlir::TF::Uint8Type>()">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Activation function enum definitions. // Activation function enum definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -77,11 +84,17 @@ def TFL_AFAttr : StrEnumAttr<
// These should match the padding enum in TFLite schema. // These should match the padding enum in TFLite schema.
def TFL_PAD_Same : StrEnumAttrCase<"SAME">; def TFL_PAD_Same : StrEnumAttrCase<"SAME">;
def TFL_PAD_Valid : StrEnumAttrCase<"VALID">; def TFL_PAD_Valid : StrEnumAttrCase<"VALID">;
def TFL_MIRRORPAD_Reflect : StrEnumAttrCase<"REFLECT">;
def TFL_MIRRORPAD_Symmetric : StrEnumAttrCase<"SYMMETRIC">;
def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum", [ def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum", [
TFL_PAD_Same, TFL_PAD_Valid TFL_PAD_Same, TFL_PAD_Valid
]>; ]>;
def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [
TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric
]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Min-max range pair definitions. // Min-max range pair definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -432,13 +445,13 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
}]; }];
let arguments = ( let arguments = (
ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>>:$values, ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>>:$values,
I32Attr:$axis, I32Attr:$axis,
TFL_AFAttr:$fused_activation_function TFL_AFAttr:$fused_activation_function
); );
let results = (outs let results = (outs
TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>:$output TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>:$output
); );
let hasOptions = 1; let hasOptions = 1;
@ -553,9 +566,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [Broadcastable, NoSideEffect]> {
}]; }];
let arguments = ( let arguments = (
// TODO(haoliang): missing Uint8 ins TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$lhs,
ins TensorOf<[F32, I32, I64, I8]>:$lhs, TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$rhs);
TensorOf<[F32, I32, I64, I8]>:$rhs);
let results = (outs TFL_BoolTensor:$output); let results = (outs TFL_BoolTensor:$output);
@ -665,10 +677,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
}]; }];
let arguments = ( let arguments = (
// TODO: missing Uint8
ins ins
TensorOf<[I1, F32, I32, I64, I8]>:$x, TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$x,
TensorOf<[I1, F32, I32, I64, I8]>:$y TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$y
); );
let results = (outs TFL_BoolTensor:$output); let results = (outs TFL_BoolTensor:$output);
@ -1066,8 +1077,7 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
}]; }];
let arguments = (ins let arguments = (ins
// TODO: missing uint8 TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
TensorOf<[F32, I8, I32, I64]>:$input,
TensorOf<[I32, I64]>:$axis, TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims BoolAttr:$keep_dims
); );
@ -1686,9 +1696,10 @@ def TFL_TanhOp: TFL_Op<"tanh", [
Computes element-wise Hyperbolic tangent of input Computes element-wise Hyperbolic tangent of input
}]; }];
let arguments = (ins AnyTensor:$x); // TODO(haoliang): missing Uint8.
let arguments = (ins TensorOf<[F32, I16, I8]>:$x);
let results = (outs AnyTensor:$y); let results = (outs TensorOf<[F32, I16, I8]>:$y);
} }
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@ -1958,6 +1969,61 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
let hasOptions = 1; let hasOptions = 1;
} }
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Cast operator";
let description = [{
Casts input from input type to output type.
}];
// TODO(b/135538711): Add complex types here.
let arguments = (ins
TensorOf<[F32, I1, I32, I64]>:$input
);
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
let hasOptions = 0;
}
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
let description = [{
This operation pads an input with mirrored values according to the paddings
you specify. paddings is an integer tensor with shape [n, 2],
where n is the rank of input.
For each dimension D of input, paddings[D, 0] indicates how many values
to add before the contents of input in that dimension,
and paddings[D, 1] indicates how many values to add after the contents of
input in that dimension.
Both paddings[D, 0] and paddings[D, 1] must be no greater than
input.dim_size(D) (or input.dim_size(D) - 1)
if copy_border is true (if false, respectively).
The padded size of each dimension D of the output is:
paddings(D, 0) + input.dim_size(D) + paddings(D, 1)
}];
let arguments = (ins
// TODO: add uint8 support when ready.
TensorOf<[F32, I32, I64]>:$input,
TensorOf<[I32, I64]>:$pad,
TFL_MirrorPaddingAttr:$mode
);
let results = (outs
TensorOf<[F32, I32, I64]>:$output
);
let hasOptions = 1;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Quantization ops. // Quantization ops.
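
A brief illustration of the mirror-pad semantics described in
`TFL_MirrorPadOp` above, using `tf.pad` with `mode='REFLECT'`/`'SYMMETRIC'`
(the TensorFlow counterpart expected to lower to `tf.MirrorPad`); the values
are illustrative:

```python
import tensorflow as tf

x = tf.constant([1, 2, 3])
paddings = tf.constant([[1, 1]])  # shape [n, 2] with n = rank(input) = 1

# REFLECT excludes the border value: paddings[D, i] <= dim_size(D) - 1.
print(tf.pad(x, paddings, mode='REFLECT'))    # [2, 1, 2, 3, 2]

# SYMMETRIC includes the border value: paddings[D, i] <= dim_size(D).
print(tf.pad(x, paddings, mode='SYMMETRIC'))  # [1, 1, 2, 3, 3]

# Padded size per dimension D: paddings[D, 0] + dim_size(D) + paddings[D, 1].
```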

View File

@ -75,7 +75,7 @@ class FixedResultUniformScale {
Builder builder(op->getContext()); Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth); IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) * const double scale = static_cast<double>(ScaleMantissa) *
::exp10(static_cast<double>(ScaleExp)); ::pow(10.0, static_cast<double>(ScaleExp));
return UniformQuantizedType::getChecked( return UniformQuantizedType::getChecked(
Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Parser.h" // TF:local_config_mlir #include "mlir/Parser.h" // TF:local_config_mlir
@ -98,7 +99,7 @@ int main(int argc, char** argv) {
if (!module) return 1; if (!module) return 1;
// TODO(jpienaar): Expand to support inputs. // TODO(jpienaar): Expand to support inputs.
mlir::Function main = module->getNamedFunction("main"); mlir::FuncOp main = module->lookupSymbol<mlir::FuncOp>("main");
QCHECK(main) << "No 'main' function specified."; QCHECK(main) << "No 'main' function specified.";
if (main.getType().getNumInputs() != 0) if (main.getType().getNumInputs() != 0)
LOG(QFATAL) << "NYI: Only nullary functions supported."; LOG(QFATAL) << "NYI: Only nullary functions supported.";

View File

@ -43,6 +43,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_INT64; return DT_INT64;
case toco::IODataType::STRING: case toco::IODataType::STRING:
return DT_STRING; return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
default: default:
return DT_INVALID; return DT_INVALID;
} }

View File

@ -829,3 +829,39 @@ func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor
// CHECK-LABEL: slice1Tensor // CHECK-LABEL: slice1Tensor
// CHECK: "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32> // CHECK: "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
} }
func @mirror_pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%0 = "tf.MirrorPad"(%arg0, %arg1) { mode = "SYMMETRIC" }: (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: mirror_pad
// CHECK: %0 = "tfl.mirror_pad"(%arg0, %arg1) {mode = "SYMMETRIC"} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
// CHECK: return %0 : tensor<?xf32>
}
func @mirror_pad_reflect(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%0 = "tf.MirrorPad"(%arg0, %arg1) { mode = "REFLECT" }: (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: mirror_pad_reflect
// CHECK: %0 = "tfl.mirror_pad"(%arg0, %arg1) {mode = "REFLECT"} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
// CHECK: return %0 : tensor<?xf32>
}
func @Tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.Tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
// CHECK-LABEL: Tanh
// CHECK: %0 = "tfl.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
}
func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: cast
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}

View File

@ -1,6 +1,7 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
^bb0(%arg0: tensor<3x2xi32>): ^bb0(%arg0: tensor<3x2xi32>):
// CHECK: { // CHECK: {
// CHECK-NEXT: version: 3, // CHECK-NEXT: version: 3,
@ -14,7 +15,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: shape: [ 3, 2 ], // CHECK-NEXT: shape: [ 3, 2 ],
// CHECK-NEXT: type: INT32, // CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 1, // CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "Input", // CHECK-NEXT: name: "input",
// CHECK-NEXT: quantization: { // CHECK-NEXT: quantization: {
// CHECK-EMPTY: // CHECK-EMPTY:
// CHECK-NEXT: } // CHECK-NEXT: }
@ -38,7 +39,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: shape: [ ], // CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32, // CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 4, // CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "Const2", // CHECK-NEXT: name: "SameNameAsOutput1",
// CHECK-NEXT: quantization: { // CHECK-NEXT: quantization: {
// CHECK-EMPTY: // CHECK-EMPTY:
// CHECK-NEXT: } // CHECK-NEXT: }
@ -46,7 +47,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: shape: [ ], // CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32, // CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 5, // CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "add", // CHECK-NEXT: name: "SameNameAsOutput",
// CHECK-NEXT: quantization: { // CHECK-NEXT: quantization: {
// CHECK-EMPTY: // CHECK-EMPTY:
// CHECK-NEXT: } // CHECK-NEXT: }
@ -90,7 +91,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input") %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
%1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const") %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
%2 = "tfl.sub" (%0, %1) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub") %2 = "tfl.sub" (%0, %1) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub")
%3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("Const2") %3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("SameNameAsOutput")
%4 = "tfl.add" (%3, %2) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add") %4 = "tfl.add" (%3, %2) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add")
return %4 : tensor<3x2xi32> return %4 : tensor<3x2xi32>
} }

View File

@ -817,7 +817,7 @@ func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2
// ----- // -----
func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> { func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
// expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values}} // expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type or TFLite uint8 type values}}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
} }

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
@ -32,8 +33,9 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
using mlir::FuncOp;
using mlir::MLIRContext; using mlir::MLIRContext;
using mlir::Module; using mlir::ModuleOp;
using stream_executor::port::StatusOr; using stream_executor::port::StatusOr;
using tensorflow::Status; using tensorflow::Status;
@ -47,7 +49,7 @@ static llvm::cl::opt<bool> print_function_result_mapping(
enum TranslationStatus { kTrSuccess, kTrFailure }; enum TranslationStatus { kTrSuccess, kTrFailure };
static int PrintFunctionResultMapping(const std::string &result, static int PrintFunctionResultMapping(const std::string &result,
Module module) { ModuleOp module) {
// Build model from the resultant string to extract the return values from // Build model from the resultant string to extract the return values from
// their source of truth. // their source of truth.
auto model = auto model =
@ -83,7 +85,7 @@ static int PrintFunctionResultMapping(const std::string &result,
std::cout << '\'' << subgraph_name << "' outputs:\n"; std::cout << '\'' << subgraph_name << "' outputs:\n";
mlir::Operation *terminator = nullptr; mlir::Operation *terminator = nullptr;
if (subgraph->name()) { if (subgraph->name()) {
if (auto fn = module.getNamedFunction(subgraph->name()->str())) if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
terminator = fn.back().getTerminator(); terminator = fn.back().getTerminator();
} }
i = 0; i = 0;

View File

@ -34,7 +34,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
using mlir::MLIRContext; using mlir::MLIRContext;
using mlir::Module; using mlir::ModuleOp;
using mlir::OwningModuleRef; using mlir::OwningModuleRef;
using stream_executor::port::StatusOr; using stream_executor::port::StatusOr;
@ -87,8 +87,8 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
context); context);
} }
bool ShouldRunQuantizePasses(mlir::Module m) { bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
if (mlir::Function main_fn = m.getNamedFunction("main")) { if (mlir::FuncOp main_fn = m.lookupSymbol<mlir::FuncOp>("main")) {
return main_fn.getAttrOfType<mlir::UnitAttr>("tf.quantize") != return main_fn.getAttrOfType<mlir::UnitAttr>("tf.quantize") !=
mlir::Attribute(); mlir::Attribute();
} }
@ -100,6 +100,16 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
bool lower_tensor_list_ops, bool lower_tensor_list_ops,
mlir::PassManager *pass_manager) { mlir::PassManager *pass_manager) {
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
if (lower_tensor_list_ops) {
// Execute this pass before `CanonicalizerPass` in case some TensorList
// ops are constant folded into variant types.
// TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
// handle constant ops that produce `TensorList`.
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// TODO(jpienaar): Revise post dialect constants. // TODO(jpienaar): Revise post dialect constants.
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
// Canonicalization includes const folding, which is utilized here to optimize // Canonicalization includes const folding, which is utilized here to optimize
@ -112,10 +122,6 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
if (emit_builtin_tflite_ops) { if (emit_builtin_tflite_ops) {
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect. // the TFLite dialect.
// TODO(haoliang): Add this pass by default.
if (lower_tensor_list_ops) {
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass()); pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
pass_manager->addPass(mlir::createCanonicalizerPass()); pass_manager->addPass(mlir::createCanonicalizerPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
@ -132,7 +138,7 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
} }
Status ConvertTFControlFlowToTFLOrFlatbuffer( Status ConvertTFControlFlowToTFLOrFlatbuffer(
mlir::Module module, bool export_to_mlir, bool emit_builtin_tflite_ops, mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
bool lower_tensor_list_ops, std::string *result) { bool lower_tensor_list_ops, std::string *result) {
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),

View File

@ -46,7 +46,7 @@ LoadFromGraphdefOrMlirSource(
// attribute "tf.quantize" by the importer module. // attribute "tf.quantize" by the importer module.
// TODO(fengliuai): switch to the cmd flag once the flags are moved to this // TODO(fengliuai): switch to the cmd flag once the flags are moved to this
// file with main method. // file with main method.
bool ShouldRunQuantizePasses(mlir::Module m); bool ShouldRunQuantizePasses(mlir::ModuleOp m);
// Add the MLIR passes that convert TF control flow dialect to TF Lite dialect // Add the MLIR passes that convert TF control flow dialect to TF Lite dialect
// to a MLIR `pass_manager`. These passes first raise the control flow in the TF // to a MLIR `pass_manager`. These passes first raise the control flow in the TF
@ -69,7 +69,7 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
// main function, Quantization is applied. If `export_to_mlir` is true, the // main function, Quantization is applied. If `export_to_mlir` is true, the
// result is exported in MLIR text format, otherwise exported in flat buffer. // result is exported in MLIR text format, otherwise exported in flat buffer.
Status ConvertTFControlFlowToTFLOrFlatbuffer( Status ConvertTFControlFlowToTFLOrFlatbuffer(
mlir::Module module, bool export_to_mlir, bool emit_builtin_tflite_ops, mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
bool lower_tensor_list_ops, std::string* result); bool lower_tensor_list_ops, std::string* result);
} // namespace tensorflow } // namespace tensorflow

View File

@ -131,6 +131,7 @@ def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>; def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>; def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>; def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
@ -228,18 +229,25 @@ def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $a
def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>; def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
// TopK in TFL is always sorted so we ignore that attribute here.
def : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>;
def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>; def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>; def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>; def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>; def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>;
def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>;
def : Pat< def : Pat<
(TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask), (TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask),
(TFL_StridedSliceOp $input, $begin, $end, $strides, (TFL_StridedSliceOp $input, $begin, $end, $strides,

View File

@ -72,7 +72,6 @@ DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Pack); DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Split); DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV); DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(TopKV2);
DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Unpack);
#undef DECL_CONVERT_OP #undef DECL_CONVERT_OP
@ -207,14 +206,6 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
return matchSuccess(); return matchSuccess();
} }
PatternMatchResult ConvertTFTopKV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
// TopK in TFL is always sorted so we ignore that attribute here.
rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, op->getOperand(0),
op->getOperand(1));
return matchSuccess();
}
PatternMatchResult ConvertTFUnpackOp::matchAndRewrite( PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const { Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op); auto tf_unpack_op = cast<TF::UnpackOp>(op);
@ -239,7 +230,7 @@ void LegalizeTF::runOnFunction() {
populateWithGenerated(ctx, &patterns); populateWithGenerated(ctx, &patterns);
RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp, RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp,
ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp, ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFTopKV2Op, ConvertTFSplitOp, ConvertTFSplitVOp,
ConvertTFUnpackOp>::build(patterns, ctx); ConvertTFUnpackOp>::build(patterns, ctx);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }

View File

@ -61,7 +61,7 @@ namespace {
class TensorListPatternRewriter : public PatternRewriter { class TensorListPatternRewriter : public PatternRewriter {
public: public:
explicit TensorListPatternRewriter(Function fn) explicit TensorListPatternRewriter(FuncOp fn)
: PatternRewriter(fn.getBody()) {} : PatternRewriter(fn.getBody()) {}
Operation *createOperation(const OperationState &state) override { Operation *createOperation(const OperationState &state) override {
@ -77,7 +77,7 @@ struct LowerStaticTensorListPass
void runOnModule() override; void runOnModule() override;
// Apply type and op changes within a function. // Apply type and op changes within a function.
LogicalResult RewriteFunction(Function func, LogicalResult RewriteFunction(FuncOp func,
TensorListPatternRewriter *rewriter); TensorListPatternRewriter *rewriter);
// Changes the function type of `cond_func` and `body_func`, and the result // Changes the function type of `cond_func` and `body_func`, and the result
@ -276,8 +276,8 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
auto *context = &getContext(); auto *context = &getContext();
auto module = getModule(); auto module = getModule();
Function cond_func = module.getNamedFunction(while_op->getCond()); FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->getCond());
Function body_func = module.getNamedFunction(while_op->getBody()); FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->getBody());
if (cond_func) { if (cond_func) {
// Change `cond_func`'s argument types to `unranked_argument_types`. // Change `cond_func`'s argument types to `unranked_argument_types`.
@ -327,7 +327,7 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
} }
LogicalResult LowerStaticTensorListPass::RewriteFunction( LogicalResult LowerStaticTensorListPass::RewriteFunction(
Function func, TensorListPatternRewriter *rewriter) { FuncOp func, TensorListPatternRewriter *rewriter) {
auto *context = &getContext(); auto *context = &getContext();
for (Block &block : func) { for (Block &block : func) {
@ -388,7 +388,7 @@ void LowerStaticTensorListPass::runOnModule() {
// have a potential issue when one function taking a `DT_VARIANT` is processed // have a potential issue when one function taking a `DT_VARIANT` is processed
// before the function that produces the `DT_VARIANT`. We need to carefully // before the function that produces the `DT_VARIANT`. We need to carefully
// order the functions to be processed. // order the functions to be processed.
std::vector<Function> funcs_in_module; std::vector<FuncOp> funcs_in_module;
for (auto func : getModule().getOps<FuncOp>()) { for (auto func : getModule().getOps<FuncOp>()) {
// Always place the main function to be the first in the list. // Always place the main function to be the first in the list.
if (func.getName() == "main") { if (func.getName() == "main") {

View File

@ -48,7 +48,7 @@ class PostQuantizePass : public FunctionPass<PostQuantizePass> {
bool emit_quant_adaptor_ops_; bool emit_quant_adaptor_ops_;
}; };
void RemoveQuantizationAdaptorOps(Function func) { void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody()); mlir::OpBuilder builder(func.getBody());
auto& bb = func.getBlocks().front(); auto& bb = func.getBlocks().front();
auto* terminator = bb.getTerminator(); auto* terminator = bb.getTerminator();

View File

@ -50,52 +50,19 @@ struct QuantizePass : public FunctionPass<QuantizePass> {
#include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc" #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
struct QuantizeConcatOp : public RewritePattern {
explicit QuantizeConcatOp(MLIRContext* context)
: RewritePattern(QuantizeOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
};
PatternMatchResult mlir::TFL::QuantizeConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto quantize_op = cast<QuantizeOp>(op);
auto concat_op =
dyn_cast_or_null<ConcatenationOp>(quantize_op.input()->getDefiningOp());
if (!concat_op) {
return matchFailure();
}
SmallVector<Value*, 4> values;
values.reserve(concat_op.getNumOperands());
for (auto operand : concat_op.values()) {
if (auto opInst =
dyn_cast_or_null<DequantizeOp>(operand->getDefiningOp())) {
values.push_back(opInst.input());
} else {
return matchFailure();
}
}
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, quantize_op.output()->getType(), values,
rewriter.getI32IntegerAttr(concat_op.axis().getZExtValue()),
rewriter.getStringAttr(concat_op.fused_activation_function()));
return matchSuccess();
}
void QuantizePass::runOnFunction() { void QuantizePass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
auto* ctx = func.getContext(); auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns); TFL::populateWithGenerated(ctx, &patterns);
mlir::RewriteListBuilder<mlir::TFL::QuantizeConcatOp>::build(patterns, ctx); mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
applyPatternsGreedily(func, std::move(patterns)); applyPatternsGreedily(func, std::move(patterns));
} }
} // namespace } // namespace
// Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
FunctionPassBase *CreateQuantizePass() { return new QuantizePass(); } FunctionPassBase* CreateQuantizePass() { return new QuantizePass(); }
static PassRegistration<QuantizePass> pass( static PassRegistration<QuantizePass> pass(
"tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect"); "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");

View File

@ -22,10 +22,6 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using quantization parameter from %1. // Quantize attribute $0 by using quantization parameter from %1.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">; def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
// Call the generic builder of `op`. Use the result type of $0 in the new op.
class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;
// Squash tfl.dequantize and tfl.quantize pairs. // Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be // TODO(fengliuai): Compare the scale of input and output. This can also be
// squashed to a requantize op if the scales are different. // squashed to a requantize op if the scales are different.
@ -39,98 +35,3 @@ def : Pat<(TFL_QuantizeOp
(TFL_QConstOp (TFL_QConstOp
$qtype, $qtype,
(QuantizeByQuantizedType $value, $qtype))>; (QuantizeByQuantizedType $value, $qtype))>;
// Quantize the AddOp if both inputs are dequantized and the output is
// quantized.
def : Pat<(TFL_QuantizeOp:$q
(TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
$fused_activation_function),
$output_type),
(ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
$fused_activation_function)>;
// Quantize the Conv2DOp if the input and weight are dequantized. The scale of
// the bias input is determined by the scales of input and weight operands.
def : Pat<(TFL_QuantizeOp
(TFL_Conv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w),
$output_type),
(TFL_Conv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w)>;
// Quantize the DepthwiseConv2DOp if the input and weight are dequantized. The
// scale of the bias input is determined by the scales of input and weight
// operands.
def : Pat<(TFL_QuantizeOp
(TFL_DepthwiseConv2DOp
(TFL_DequantizeOp $in),
(TFL_DequantizeOp $weight),
(TFL_DequantizeOp $bias),
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier),
$output_type),
(TFL_DepthwiseConv2DOp
$in,
$weight,
$bias,
$dilation_h_factor,
$dilation_w_factor,
$fused_activation_function,
$padding,
$stride_h,
$stride_w,
$multiplier)>;
// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass can guarantee both quantization parameters are the
// same.
def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
(TFL_ReshapeOp $in)>;
// Quantize the ReshapeOp if the input is dequantized and output is quantized.
// The pre-quantize pass has set the output quantization parameters to a
// pre-defined value.
def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
$output_type),
(TFL_SoftmaxOp $in, $beta)>;
// Quantize the AveragePool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same.
def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w), $output_type),
(TFL_AveragePool2DOp $in,
$filter_height, $filter_width, $fused_activation_function,
$padding, $stride_h, $stride_w)>;
// Quantize the MaxPool2DOp if the input is dequantized and output is
// quantized. The pre-quantize pass can guarantee both quantization parameters
// are the same.
def : Pat<(TFL_QuantizeOp (TFL_MaxPool2DOp (TFL_DequantizeOp $in),
$padding, $stride_w, $tride_h, $stride_width, $stride_height,
$fused_activation_function), $output_type),
(TFL_MaxPool2DOp $in,
$padding, $stride_w, $tride_h, $stride_width, $stride_height,
$fused_activation_function)>;

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/Matchers.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir
@ -121,7 +122,7 @@ struct RequantizeState {
// //
class QuantizationDriver { class QuantizationDriver {
public: public:
explicit QuantizationDriver(Function fn) : builder_(fn.getBody()) {} explicit QuantizationDriver(FuncOp fn) : builder_(fn.getBody()) {}
// The entry point of the quantization parameters propagation. // The entry point of the quantization parameters propagation.
void Run(); void Run();
@ -706,7 +707,7 @@ void QuantizationDriver::Run() {
} }
} }
void ApplyQuantizationParamsPropagation(mlir::Function func) { void ApplyQuantizationParamsPropagation(mlir::FuncOp func) {
QuantizationDriver(func).Run(); QuantizationDriver(func).Run();
} }

View File

@ -20,12 +20,66 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
namespace mlir { namespace mlir {
namespace TFL { namespace TFL {
// A generic rewrite pattern which matches any N-in-1-out operations with
// quantization parameters propagated to all the operands and results values.
// The quantization parameters are annotated by the Q/DQ op pairs. Each matched
// pattern is rewritten by its quantized alternative.
//
// This pattern assumes all the matched ops are quantizable. This assumption is
// always right, except when a "Q" op is used as a requantize op. For non-"Q"
// ops, quantization parameters should be propagated to their result.
//
// This pattern only matches ops which only have one result.
template <typename Q, typename DQ>
struct GenericFullQuantizationPattern : public RewritePattern {
explicit GenericFullQuantizationPattern(MLIRContext* context)
: RewritePattern(Q::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
if (op->getNumResults() != 1) {
return matchFailure();
}
auto quantize_op = cast<Q>(op);
auto quantized_op = quantize_op.input()->getDefiningOp();
// If it is a block argument, requantize op, or has more than one result, we
// shouldn't rewrite this op.
if (!quantized_op || llvm::isa<Q>(quantized_op) ||
llvm::isa<DQ>(quantized_op) || quantized_op->getNumResults() != 1) {
return matchFailure();
}
// Collect all the quantized inputs and "clone" the matched op by these
// inputs.
SmallVector<Value*, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) {
auto* operand = quantized_op->getOperand(i);
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
inputs.push_back(op_inst.input());
} else {
return matchFailure();
}
}
// Use OpBuilder so we can use op name to create the new op.
OpBuilder builder(quantized_op);
OperationState new_state(
quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs,
op->getResult(0)->getType(), quantized_op->getAttrs());
Operation* new_op = builder.createOperation(new_state);
rewriter.replaceOp(op, {new_op->getResult(0)});
return matchSuccess();
}
};
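For intuition, here is a toy Python sketch of the rewrite that the GenericFullQuantizationPattern above performs (hypothetical data structures, not the MLIR C++ implementation): a `Q` whose producer has every operand fed by a `DQ` is replaced by a clone of that producer consuming the quantized values directly.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Node:
    name: str                      # e.g. "tfl.add", "Q", "DQ", or a value name
    operands: List["Node"] = field(default_factory=list)

def rewrite_quantize(q: Node) -> Node:
    """Collapse Q(op(DQ(a), DQ(b), ...)) into op(a, b, ...); otherwise return q."""
    if q.name != "Q" or len(q.operands) != 1:
        return q
    op = q.operands[0]
    # Skip requantize (Q feeding Q) and bare DQ pass-throughs.
    if op.name in ("Q", "DQ"):
        return q
    quantized_inputs = []
    for operand in op.operands:
        if operand.name != "DQ" or len(operand.operands) != 1:
            return q                # not fully annotated: no rewrite
        quantized_inputs.append(operand.operands[0])
    return Node(op.name, quantized_inputs)

x_q, w_q = Node("x_q"), Node("w_q")
add = Node("tfl.add", [Node("DQ", [x_q]), Node("DQ", [w_q])])
result = rewrite_quantize(Node("Q", [add]))
print(result.name, [v.name for v in result.operands])  # tfl.add ['x_q', 'w_q']
```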
// Converts the min/max/storage_type/narrow_range information to a // Converts the min/max/storage_type/narrow_range information to a
// QuantizedType, and then returns the attribute containing the QuantizedType. // QuantizedType, and then returns the attribute containing the QuantizedType.
TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min, TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
@ -62,7 +116,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
// quantization parameters are stored as adjacent quantize and dequantize ops // quantization parameters are stored as adjacent quantize and dequantize ops
// and the propagation results are materialized by inserting pairs of quantize // and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. // and dequantize ops to this function.
void ApplyQuantizationParamsPropagation(mlir::Function func); void ApplyQuantizationParamsPropagation(mlir::FuncOp func);
} // end namespace TFL } // end namespace TFL
} // end namespace mlir } // end namespace mlir

View File

@ -103,6 +103,7 @@ cc_library(
"transforms/generated_optimize.inc", "transforms/generated_optimize.inc",
"transforms/optimize.cc", "transforms/optimize.cc",
"transforms/raise_control_flow.cc", "transforms/raise_control_flow.cc",
"translate/control_to_executor_dialect.cc",
], ],
hdrs = [ hdrs = [
"ir/control_flow_ops.h", "ir/control_flow_ops.h",
@ -280,6 +281,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc", "//tensorflow/core:protos_all_proto_cc",
"@local_config_mlir//:IR", "@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
], ],
) )
@ -471,6 +473,7 @@ cc_library(
"@llvm//:support", "@llvm//:support",
"@local_config_mlir//:IR", "@local_config_mlir//:IR",
"@local_config_mlir//:Parser", "@local_config_mlir//:Parser",
"@local_config_mlir//:Pass",
], ],
) )

View File

@ -694,7 +694,7 @@ void Print(NextIterationSinkOp next_iteration, OpAsmPrinter *p) {
*p << next_iteration.getOperationName() << " ["; *p << next_iteration.getOperationName() << " [";
p->printOperand(next_iteration.getOperand(0)); p->printOperand(next_iteration.getOperand(0));
*p << "] "; *p << "] ";
p->printOperand(next_iteration.getOperand(1)); p->printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
*p << " : " << next_iteration.getOperand(1)->getType(); *p << " : " << next_iteration.getOperand(1)->getType();
p->printOptionalAttrDict(next_iteration.getAttrs()); p->printOptionalAttrDict(next_iteration.getAttrs());
} }

View File

@ -250,25 +250,6 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
TfeControlType: $control TfeControlType: $control
); );
let builders = [OpBuilder<
"Builder *builder, OperationState *result, ArrayRef<Value *> operands = {}",
[{
assert(operands.size() >= 2 && "tf_executor.Switch builder expects at "
"least two operands");
return build(builder, result, operands[0], operands[1], operands.drop_front(2));
}]>,
OpBuilder<
"Builder *builder, OperationState *result, Value *data, Value *predicate, ArrayRef<Value *> controls = {}",
[{
Type dataTy = data->getType();
Type controlTy = ControlType::get(builder->getContext());
result->types = { dataTy, dataTy, controlTy };
result->operands.push_back(data);
result->operands.push_back(predicate);
result->operands.insert(result->operands.end(), controls.begin(), controls.end());
}]>
];
let verifier = ?; let verifier = ?;
} }
@ -311,7 +292,6 @@ def TfExecutor_SwitchNOp :
Variadic<AnyType>:$outputs, Variadic<AnyType>:$outputs,
TfeControlType: $control TfeControlType: $control
); );
} }
def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> { def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> {
@ -346,18 +326,6 @@ def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAf
TensorOf<[I32]>:$valueIndex, TensorOf<[I32]>:$valueIndex,
TfeControlType:$control TfeControlType:$control
); );
let builders = [OpBuilder<
"Builder *builder, OperationState *result, ArrayRef<Value *> operands",
[{
assert(operands.size() >= 1 && "tf_executor.Merge builder expects at "
"least one operand");
Type data_type = operands[0]->getType();
Type control_type = ControlType::get(builder->getContext());
result->types = { data_type, builder->getIntegerType(32), control_type};
result->operands.append(operands.begin(), operands.end());
}]>
];
} }
def TfExecutor_EnterOp : TfExecutor_Op<"Enter", def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
@ -408,19 +376,6 @@ def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
); );
let verifier = ?; let verifier = ?;
let builders = [OpBuilder<
"Builder *builder, OperationState *result, ArrayRef<Value *> operands",
[{
assert(operands.size() >= 1 && "tf_executor.Enter builder "
"expects at least one operand");
result->operands.append(operands.begin(), operands.end());
Type control_type = ControlType::get(builder->getContext());
result->types.push_back(operands[0]->getType());
result->types.push_back(control_type);
}]>
];
} }
def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [NoSideEffect]> { def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [NoSideEffect]> {
@ -472,12 +427,14 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [No
); );
let builders = [OpBuilder< let builders = [OpBuilder<
"Builder *builder, OperationState *result, Type resultTy, ArrayRef<Value *> controlInputs = {}", "Builder *builder, OperationState *result, Type result_type, "
"ArrayRef<Value *> control_inputs = {}, ArrayRef<NamedAttribute> attributes = {}",
[{ [{
Type tokenTy = TokenType::get(builder->getContext()); Type token_type = TokenType::get(builder->getContext());
Type controlTy = ControlType::get(builder->getContext()); Type control_type = ControlType::get(builder->getContext());
result->types = { resultTy, tokenTy, controlTy }; result->types = { result_type, token_type, control_type };
result->operands.append(controlInputs.begin(), controlInputs.end()); result->operands.append(control_inputs.begin(), control_inputs.end());
result->attributes.append(attributes.begin(), attributes.end());
}]> }]>
]; ];
} }
@ -527,6 +484,19 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> {
// Optional extra control inputs. // Optional extra control inputs.
Variadic<TfeControlType>:$controlInputs Variadic<TfeControlType>:$controlInputs
); );
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Value *token, "
"ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
[{
assert(operands.size() >= 1 && "tf_executor.NextIteration.Sink builder "
"expects at least one operand");
result->operands.push_back(token);
result->operands.insert(result->operands.end(), operands.begin(),
operands.end());
result->attributes.append(attributes.begin(), attributes.end());
}]>
];
} }
def TfExecutor_ExitOp : TfExecutor_Op<"Exit", def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
@ -590,6 +560,20 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", [NoSideEffect]
); );
let verifier = ?; let verifier = ?;
let builders = [OpBuilder<
"Builder *builder, OperationState *result, "
"ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
[{
assert(operands.size() >= 1 && "tf_executor.ControlTrigger builder "
"expects at least one operand");
result->operands.insert(result->operands.end(), operands.begin(),
operands.end());
Type control_type = ControlType::get(builder->getContext());
result->types = {control_type};
result->attributes.append(attributes.begin(), attributes.end());
}]>
];
} }
def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> { def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> {

View File

@ -79,6 +79,12 @@ def TF_AddNOp : TF_Op<"AddN", [Commutative, NoSideEffect]> {
let summary = "Add all input tensors element wise."; let summary = "Add all input tensors element wise.";
let description = [{ let description = [{
Inputs must be of the same size and shape.
```python
x = [9, 7, 10]
tf.math.add_n(x) ==> 26
```
}]; }];
let arguments = (ins let arguments = (ins
@ -467,6 +473,15 @@ def TF_CosOp : TF_Op<"Cos", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes cos of x element-wise."; let summary = "Computes cos of x element-wise.";
let description = [{ let description = [{
Given an input tensor, this function computes cosine of every
element in the tensor. Input range is `(-inf, inf)` and
output range is `[-1,1]`. If input lies outside the boundary, `nan`
is returned.
```python
x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")])
tf.math.cos(x) ==> [nan -0.91113025 0.87758255 0.5403023 0.36235774 0.48718765 -0.95215535 nan]
```
}]; }];
let arguments = (ins let arguments = (ins
@ -1027,6 +1042,43 @@ Invert (flip) each bit of supported types; for example, type `uint8` value 01010
let description = [{ let description = [{
Flip each bit of supported types. For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101. Flip each bit of supported types. For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101.
This operation is performed on each element of the tensor argument `x`. This operation is performed on each element of the tensor argument `x`.
Example:
```python
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import bitwise_ops
# flip 2 (00000010) to -3 (11111101)
tf.assert_equal(-3, bitwise_ops.invert(2))
dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
inputs = [0, 5, 3, 14]
for dtype in dtype_list:
# Because of issues with negative numbers, let's test this indirectly.
# 1. invert(a) and a = 0
# 2. invert(a) or a = invert(0)
input_tensor = tf.constant([0, 5, 3, 14], dtype=dtype)
not_a_and_a, not_a_or_a, not_0 = [bitwise_ops.bitwise_and(
input_tensor, bitwise_ops.invert(input_tensor)),
bitwise_ops.bitwise_or(
input_tensor, bitwise_ops.invert(input_tensor)),
bitwise_ops.invert(
tf.constant(0, dtype=dtype))]
expected = tf.constant([0, 0, 0, 0], dtype=tf.float32)
tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected)
expected = tf.cast([not_0] * 4, tf.float32)
tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected)
# For unsigned dtypes let's also check the result directly.
if dtype.is_unsigned:
inverted = bitwise_ops.invert(input_tensor)
expected = tf.constant([dtype.max - x for x in inputs], dtype=tf.float32)
tf.assert_equal(tf.cast(inverted, tf.float32), tf.cast(expected, tf.float32))
```
}]; }];
let arguments = (ins let arguments = (ins
@ -1348,6 +1400,52 @@ def TF_MinimumOp : TF_Op<"Minimum", [Broadcastable, NoSideEffect]>,
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_MirrorPadOp : TF_Op<"MirrorPad", [NoSideEffect]> {
let summary = "Pads a tensor with mirrored values.";
let description = [{
This operation pads an `input` with mirrored values according to the `paddings`
you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
how many values to add before the contents of `input` in that dimension, and
`paddings[D, 1]` indicates how many values to add after the contents of `input`
in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true
(if false, respectively).
The padded size of each dimension D of the output is:
`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
For example:
```
# 't' is [[1, 2, 3], [4, 5, 6]].
# 'paddings' is [[1, 1], [2, 2]].
# 'mode' is SYMMETRIC.
# rank of 't' is 2.
pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
[2, 1, 1, 2, 3, 3, 2]
[5, 4, 4, 5, 6, 6, 5]
[5, 4, 4, 5, 6, 6, 5]]
```
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$paddings,
TF_AnyStrAttrOf<["REFLECT", "SYMMETRIC"]>:$mode
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>, def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
WithBroadcastableBinOpBuilder { WithBroadcastableBinOpBuilder {
let summary = "Returns x * y element-wise."; let summary = "Returns x * y element-wise.";
@ -2194,9 +2292,17 @@ Specifically, `y = 1 / (1 + exp(-x))`.
} }
def TF_SinOp : TF_Op<"Sin", [NoSideEffect, SameOperandsAndResultType]> { def TF_SinOp : TF_Op<"Sin", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes sin of x element-wise."; let summary = "Computes sine of x element-wise.";
let description = [{ let description = [{
Given an input tensor, this function computes sine of every
element in the tensor. Input range is `(-inf, inf)` and
output range is `[-1,1]`.
```python
x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10, float("inf")])
tf.math.sin(x) ==> [nan -0.4121185 -0.47942555 0.84147096 0.9320391 -0.87329733 -0.54402107 nan]
```
}]; }];
let arguments = (ins let arguments = (ins
@ -2591,6 +2697,31 @@ retained with length 1.
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
} }
def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes hyperbolic tangent of `x` element-wise.";
let description = [{
Given an input tensor, this function computes hyperbolic tangent of every
element in the tensor. Input range is `[-inf, inf]` and
output range is `[-1,1]`.
```python
x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
tf.math.tanh(x) ==> [-1. -0.99990916 -0.46211717 0.7615942 0.8336547 0.9640276 0.9950547 1.]
```
}];
let arguments = (ins
TF_FpOrComplexTensor:$x
);
let results = (outs
TF_FpOrComplexTensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> { def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
let summary = [{ let summary = [{
Creates a TensorList which, when stacked, has the value of `tensor`. Creates a TensorList which, when stacked, has the value of `tensor`.

View File

@ -275,18 +275,18 @@ static LogicalResult Verify(FusedBatchNormOp op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult IfOp::verify() { LogicalResult IfOp::verify() {
auto thenAttr = getAttrOfType<FunctionAttr>("then_branch"); auto thenAttr = getAttrOfType<SymbolRefAttr>("then_branch");
if (!thenAttr) return emitOpError("requires then_branch attribute"); if (!thenAttr) return emitOpError("requires then_branch attribute");
auto elseAttr = getAttrOfType<FunctionAttr>("else_branch"); auto elseAttr = getAttrOfType<SymbolRefAttr>("else_branch");
if (!elseAttr) return emitOpError("requires else_branch attribute"); if (!elseAttr) return emitOpError("requires else_branch attribute");
auto module = getParentOfType<Module>(); auto module = getParentOfType<ModuleOp>();
auto thenFn = module.getNamedFunction(thenAttr.getValue()); auto thenFn = module.lookupSymbol<FuncOp>(thenAttr.getValue());
if (!thenFn) if (!thenFn)
return emitOpError("then_branch refers to an undefined function : ") return emitOpError("then_branch refers to an undefined function : ")
<< thenAttr; << thenAttr;
auto elseFn = module.getNamedFunction(elseAttr.getValue()); auto elseFn = module.lookupSymbol<FuncOp>(elseAttr.getValue());
if (!elseFn) if (!elseFn)
return emitOpError("else_branch refers to an undefined function : ") return emitOpError("else_branch refers to an undefined function : ")
<< elseAttr; << elseAttr;
@ -627,9 +627,10 @@ OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(SoftmaxOp op) { static LogicalResult Verify(SoftmaxOp op) {
if (!IsOfRankOrUnranked(op.logits(), 2)) if (!IsOfRankOrUnranked(op.logits(), 1) &&
return op.emitOpError("requires operand to be 2D tensor"); !IsOfRankOrUnranked(op.logits(), 2)) {
return op.emitOpError("requires operand to be 1D/2D tensor");
}
return success(); return success();
} }
@ -727,20 +728,20 @@ void TruncateDivOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult WhileOp::verify() { LogicalResult WhileOp::verify() {
auto condAttr = getAttrOfType<FunctionAttr>("cond"); auto condAttr = getAttrOfType<SymbolRefAttr>("cond");
if (!condAttr) return emitOpError("requires cond attribute"); if (!condAttr) return emitOpError("requires cond attribute");
auto module = getParentOfType<Module>(); auto module = getParentOfType<ModuleOp>();
auto condFn = module.getNamedFunction(condAttr.getValue()); auto condFn = module.lookupSymbol<FuncOp>(condAttr.getValue());
auto condFuncType = condFn.getType(); auto condFuncType = condFn.getType();
// Verify that the cond function has exactly one result. // Verify that the cond function has exactly one result.
if (condFuncType.getNumResults() != 1) if (condFuncType.getNumResults() != 1)
return emitOpError("requires cond function to have exactly one result"); return emitOpError("requires cond function to have exactly one result");
auto bodyAttr = getAttrOfType<FunctionAttr>("body"); auto bodyAttr = getAttrOfType<SymbolRefAttr>("body");
if (!bodyAttr) return emitOpError("requires body attribute"); if (!bodyAttr) return emitOpError("requires body attribute");
auto bodyFn = module.getNamedFunction(bodyAttr.getValue()); auto bodyFn = module.lookupSymbol<FuncOp>(bodyAttr.getValue());
auto bodyFuncType = bodyFn.getType(); auto bodyFuncType = bodyFn.getType();
SmallVector<Type, 4> operands(getOperandTypes()); SmallVector<Type, 4> operands(getOperandTypes());

View File

@ -119,11 +119,11 @@ class IfOp : public Op<IfOp, TensorFlowOp, OpTrait::AtLeastNOperands<1>::Impl,
// TODO(b/132271680): This is not following Google naming style // TODO(b/132271680): This is not following Google naming style
StringRef getThen() { StringRef getThen() {
return getAttrOfType<FunctionAttr>("then_branch").getValue(); return getAttrOfType<SymbolRefAttr>("then_branch").getValue();
} }
StringRef getElse() { StringRef getElse() {
return getAttrOfType<FunctionAttr>("else_branch").getValue(); return getAttrOfType<SymbolRefAttr>("else_branch").getValue();
} }
LogicalResult verify(); LogicalResult verify();
@ -157,8 +157,12 @@ class WhileOp : public Op<WhileOp, TensorFlowOp, OpTrait::VariadicOperands,
using Op::Op; using Op::Op;
static StringRef getOperationName() { return "tf.While"; } static StringRef getOperationName() { return "tf.While"; }
StringRef getCond() { return getAttrOfType<FunctionAttr>("cond").getValue(); } StringRef getCond() {
StringRef getBody() { return getAttrOfType<FunctionAttr>("body").getValue(); } return getAttrOfType<SymbolRefAttr>("cond").getValue();
}
StringRef getBody() {
return getAttrOfType<SymbolRefAttr>("body").getValue();
}
LogicalResult verify(); LogicalResult verify();
}; };

View File

@ -105,8 +105,10 @@ retained with length 1.
// In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes // In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
// of function arguments. // of function arguments.
// Note: NoSideEffect trait is not added intentionally to preserve the captured
// attributes even if the input is unused.
def TF_PlaceholderInputOp : TF_Op<"Placeholder.input", def TF_PlaceholderInputOp : TF_Op<"Placeholder.input",
[NoSideEffect, SameOperandsAndResultType]> { [SameOperandsAndResultType]> {
let summary = "PlaceholderInput op"; let summary = "PlaceholderInput op";
let description = [{ let description = [{

View File

@ -19,6 +19,10 @@ limitations under the License.
#ifdef HANDLE_TF_TYPE #ifdef HANDLE_TF_TYPE
// class, enumerant, name // class, enumerant, name
HANDLE_TF_TYPE(Uint8, UINT8, "uint8")
HANDLE_TF_TYPE(Uint16, UINT16, "uint16")
HANDLE_TF_TYPE(Uint32, UINT32, "uint32")
HANDLE_TF_TYPE(Uint64, UINT64, "uint64")
HANDLE_TF_TYPE(Qint8, QINT8, "qint8") HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
HANDLE_TF_TYPE(Qint16, QINT16, "qint16") HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
HANDLE_TF_TYPE(Qint32, QINT32, "qint32") HANDLE_TF_TYPE(Qint32, QINT32, "qint32")

View File

@ -0,0 +1,86 @@
// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @islands_with_control
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: tensor<*xf32>)
func @islands_with_control(tensor<*xf32>) -> tensor<*xf32> {
^bb0(%0: tensor<*xf32>):
%1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
%2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> tensor<*xf32>
return %2 : tensor<*xf32>
}
// CHECK-NEXT: %[[GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: tf_executor.fetch %[[ADD]]#0 : tensor<*xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[GRAPH]] : tensor<*xf32>
// CHECK-LABEL: func @LoopTest() {
func @LoopTest() {
%0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
%1:2 = "_tf.Enter"(%0#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%2 = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
%3:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
%4:3 = "_tf.Merge"(%3#0, %1#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
%5:2 = "_tf.Const"(%4#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%6:2 = "_tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
%7:2 = "_tf.LoopCond"(%6#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
%8:3 = "_tf.Switch"(%4#0, %7#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
%9:2 = "_tf.Exit"(%8#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%10:2 = "_tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%11:2 = "_tf.Const"(%10#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%12:2 = "_tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%13 = "_tf.ControlTrigger"(%2, %12#1, %9#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
%14 = "_tf.NextIteration.sink"(%12#0, %13) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> (!_tf.control)
return
}
// CHECK-NEXT: tf_executor.graph {
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[CONST]]#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island {
// CHECK-NEXT: "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
// CHECK-NEXT: tf_executor.yield
// CHECK-NEXT: }
// CHECK-NEXT: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi1>
// CHECK-NEXT: }
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[LESS:[0-9]*]]#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
// CHECK-NEXT: tf_executor.fetch
// CHECK-NEXT: }
// CHECK-NEXT: return

View File

@@ -2176,14 +2176,12 @@ versions {
# CHECK: %73:2 = "_tf.mul_2_Da30D05wlPU0"(%58#0, %72#0, %47#1) {_tpu_replicate = "cluster", device = "", name = "while/mul_2_Da30D05wlPU"} : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
# CHECK: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
# CHECK: func @less_than_5_If8q4vKg9jA0(%arg0: tensor<*xf32>) -> tensor<*xi1>
# CHECK-NEXT: attributes {tf._noinline = true} {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Less/y", value = dense<5.000000e+00> : tensor<f32>} : () -> (tensor<f32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Less"(%arg0, %0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "Less"} : (tensor<*xf32>, tensor<f32>) -> (tensor<*xi1>, !_tf.control)
# CHECK-NEXT: return %1#0 : tensor<*xi1>
# CHECK-NEXT: }
-# CHECK-EMPTY:
# CHECK: func @mul_2_Da30D05wlPU0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
# CHECK-NEXT: attributes {tf._noinline = true} {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "mul/y", value = dense<2.000000e+00> : tensor<1x1xf32>} : () -> (tensor<1x1xf32>, !_tf.control)

View File

@@ -91,8 +91,7 @@ versions {
# CHECK-NEXT: %0:2 = "_tf.PartitionedCall"() {Tin = [], Tout = ["tfdtype$DT_INT32"], config = "", config_proto = "", device = "", executor_type = "", f = @foo0, name = "PartitionedCall"} : () -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() -> tensor<i32>
+# CHECK: func @foo0() -> tensor<i32>
# CHECK-NEXT: attributes {tf.experimental_ints_on_device = true} {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<5> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Identity"(%0#0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<i32>) -> (tensor<i32>, !_tf.control)

View File

@@ -157,13 +157,11 @@ versions {
# CHECK-NEXT: %1:2 = "_tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo0, @bar0], device = "", name = "Case", output_shapes = []} : (tensor<i32>) -> (tensor<*xf32>, !_tf.control)
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() -> tensor<10xf32> {
+# CHECK: func @foo0() -> tensor<10xf32> {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_1", value = dense<1.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control)
# CHECK-NEXT: return %0#0 : tensor<10xf32>
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @bar0() -> tensor<10xf32> {
+# CHECK: func @bar0() -> tensor<10xf32> {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_2", value = dense<2.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control)
# CHECK-NEXT: return %0#0 : tensor<10xf32>
# CHECK-NEXT: }

View File

@@ -526,16 +526,13 @@ versions {
# CHECK-NEXT: %18:2 = "_tf.Identity"(%17#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_1_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @cond_false0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @cond_false0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
# CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: return %1#0, %0#0 : tensor<*xi32>, tensor<*xi32>
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @cond_true0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @cond_true0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
# CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: return %0#0, %1#0 : tensor<*xi32>, tensor<*xi32>
# CHECK-NEXT: }
-# CHECK-EMPTY:

View File

@@ -145,12 +145,10 @@ versions {
#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
#CHECK-NEXT: return
#CHECK-NEXT: }
-#CHECK-EMPTY:
-#CHECK-NEXT: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
+#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
#CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "const", value = dense<[1, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>, !_tf.control)
#CHECK-NEXT: return %0#0 : tensor<2xi32>
#CHECK-NEXT: }
-#CHECK-EMPTY:
-#CHECK-NEXT: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+#CHECK: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
#CHECK-NEXT: return %arg0 : tensor<*xi32>
#CHECK-NEXT: }

View File

@@ -303,16 +303,14 @@ versions {
# CHECK-NEXT: %3:4 = "_tf.While"(%1#0, %2#0, %0#0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, body = @while_body_60, cond = @while_cond_50, device = "", name = "while", output_shapes = ["tfshape$", "tfshape$", "tfshape$"], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, !_tf.control)
# CHECK-NEXT: return %3#2 : tensor<i32>
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @while_body_60(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @while_body_60(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Add/y", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "add_1/y", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
# CHECK-NEXT: %2:2 = "_tf.Add"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: %3:2 = "_tf.Add"(%arg0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "add_1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
# CHECK-NEXT: return %3#0, %arg1, %2#0 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32>
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @while_cond_50(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> tensor<*xi1> {
+# CHECK: func @while_cond_50(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> tensor<*xi1> {
# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Less/y", value = dense<10> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Less"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
# CHECK-NEXT: return %1#0 : tensor<*xi1>

View File

@@ -279,14 +279,12 @@ versions {
# CHECK-NEXT: %5:2 = "_tf.SymbolicGradient"(%0#0, %4#0) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], device = "", f = @foo0, f._disable_call_shape_inference = true, name = "gradients/foo_grad/SymbolicGradient"} : (tensor<f32>, tensor<*xf32>) -> (tensor<f32>, !_tf.control)
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo_grad0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
+# CHECK: func @foo_grad0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
# CHECK-NEXT: attributes {tf._disable_call_shape_inference = true} {
# CHECK-NEXT: %0:2 = "_tf.Mul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "", name = "mul_0"} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
# CHECK-NEXT: return %0#0 : tensor<*xf32>
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0(%arg0: tensor<*xf32>) -> tensor<*xf32>
+# CHECK: func @foo0(%arg0: tensor<*xf32>) -> tensor<*xf32>
# CHECK-NEXT: attributes {tf._disable_call_shape_inference = true, tf.gradient = @foo_grad0} {
# CHECK-NEXT: %0:2 = "_tf.Exp"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Exp"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
# CHECK-NEXT: %1:2 = "_tf.Neg"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Neg"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)

View File

@@ -41,12 +41,10 @@ versions {
# CHECK-NEXT: %1 = "_tf.bar0"() {device = "", name = "unnamed1"} : () -> !_tf.control
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() {
+# CHECK: func @foo0() {
# CHECK-NEXT: %0 = "_tf.bar0"() {device = "", name = "unnamed"} : () -> !_tf.control
# CHECK-NEXT: return
# CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @bar0() {
+# CHECK: func @bar0() {
# CHECK-NEXT: return
# CHECK-NEXT: }

View File

@@ -0,0 +1,111 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s
node {
name: "PartitionedCall"
op: "PartitionedCall"
attr {
key: "Tin"
value {
list {
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_UINT8
}
}
}
attr {
key: "_gradient_op_type"
value {
s: "PartitionedCall-15"
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: "\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001"
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_uint_const_14"
}
}
}
}
library {
function {
signature {
name: "__inference_uint_const_14"
output_arg {
name: "identity"
type: DT_UINT8
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_UINT8
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_UINT8
tensor_shape {
}
int_val: 5
}
}
}
}
node_def {
name: "Identity"
op: "Identity"
input: "Const:output:0"
attr {
key: "T"
value {
type: DT_UINT8
}
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
}
}
versions {
producer: 29
min_consumer: 12
}
# CHECK: func @main
# CHECK: "_tf.PartitionedCall"()
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
# CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
# CHECK: return {{%[0-9]*#[0-9]*}} : tensor<!tf.uint8>

View File

@@ -588,7 +588,7 @@ func @testSoftmax(tensor<8x16xf32>) -> tensor<8x16xf32> {
// Test invalid tf.Softmax
func @testSoftmax(tensor<8x8x8xf32>) -> tensor<8x8x8xf32> {
^bb0(%arg0: tensor<8x8x8xf32>):
-// expected-error @+1 {{requires operand to be 2D tensor}}
+// expected-error @+1 {{requires operand to be 1D/2D tensor}}
%0 = "tf.Softmax"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
return %0 : tensor<8x8x8xf32>
}

View File

@@ -313,7 +313,7 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*
%3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
// CHECK: %3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
-// CHECK: tf_executor.NextIteration.Sink [%3#1] %3#0 : tensor<*xf32>
+// CHECK: tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
tf_executor.fetch %3#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>

View File

@@ -70,7 +70,7 @@ static Value* LowerCondition(Location loc, Value* value, OpBuilder* builder) {
// that is compatible for tensor cast.
//
static Operation* CallFn(Location loc,
-const std::function<Value*(int)>& get_arg, Function fn,
+const std::function<Value*(int)>& get_arg, FuncOp fn,
OpBuilder* builder) {
FunctionType fn_type = fn.getType();
llvm::SmallVector<Value*, 4> operands;
@@ -153,9 +153,9 @@ static LogicalResult LowerIfOp(IfOp op) {
Value* cond_i1 = LowerCondition(loc, op.getCondition(), &builder);
if (!cond_i1) return failure();
-auto module = op_inst->getParentOfType<Module>();
-auto then_fn = module.getNamedFunction(op.getThen());
-auto else_fn = module.getNamedFunction(op.getElse());
+auto module = op_inst->getParentOfType<ModuleOp>();
+auto then_fn = module.lookupSymbol<FuncOp>(op.getThen());
+auto else_fn = module.lookupSymbol<FuncOp>(op.getElse());
// Split the basic block before the 'if'. The new dest will be our merge
// point.
@@ -210,9 +210,9 @@ static LogicalResult LowerWhileOp(WhileOp op) {
OpBuilder builder(op_inst);
-auto module = op_inst->getParentOfType<Module>();
-auto cond_fn = module.getNamedFunction(op.getCond());
-auto body_fn = module.getNamedFunction(op.getBody());
+auto module = op_inst->getParentOfType<ModuleOp>();
+auto cond_fn = module.lookupSymbol<FuncOp>(op.getCond());
+auto body_fn = module.lookupSymbol<FuncOp>(op.getBody());
// Split the block containing the While op into two blocks. One containing
// operations before the While op and other containing the rest. Create two
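The recurring change in this file is the MLIR API migration from Module/getNamedFunction to ModuleOp/lookupSymbol. A minimal sketch of the new lookup pattern in isolation, with illustrative function and variable names that are not taken from the patch:

// Resolve a callee FuncOp by walking up to the enclosing ModuleOp and doing a
// symbol-table lookup; lookupSymbol returns a null FuncOp if the name is absent.
static mlir::FuncOp ResolveCallee(mlir::Operation* op, llvm::StringRef callee) {
  auto module = op->getParentOfType<mlir::ModuleOp>();
  return module.lookupSymbol<mlir::FuncOp>(callee);
}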

View File

@@ -35,7 +35,6 @@ namespace TFControlFlow {
FunctionPassBase *CreateRaiseTFControlFlowPass();
} // namespace TFControlFlow
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_

View File

@@ -89,7 +89,7 @@ std::vector<GraphOptimizationPass*> GraphOptPass::FindPassIds() {
}
void GraphOptPass::runOnModule() {
-mlir::Module module_in = getModule();
+mlir::ModuleOp module_in = getModule();
mlir::MLIRContext& ctx = getContext();
// Convert MLIR to Graph

View File

@@ -0,0 +1,248 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms the MLIR TF control dialect into a combination
// of the TF and TF executor dialects.
//
// !! This code is only intended for migration purpose and will be deleted when
// !! the importer is updated to directly emit the tf_executor dialect.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#define DEBUG_TYPE "tf-ctl-to-executor"
namespace mlir {
namespace {
// This pass checks if a function contains only operations in the TF control
// dialect and converts it to a mix of the tf_executor and tf dialects.
// Operations that exist in the tf_executor dialect are used directly,
// otherwise _tf operations are wrapped in an island and the _ prefix is
// removed. Control dependencies are moved to be handled by the island itself.
struct ControlToExecutorDialectConversion
: public FunctionPass<ControlToExecutorDialectConversion> {
void runOnFunction() override;
private:
tf_executor::IslandOp CreateIslandForOp(Operation *op, OpBuilder *builder);
};
} // end anonymous namespace
static bool IsUnderscoredTFOp(Operation *op) {
return op->getName().getStringRef().startswith("_tf.");
}
static bool HasOnlyTFControlOperations(FuncOp function) {
return llvm::all_of(function, [](Block &block) {
return llvm::all_of(block, [](Operation &op) {
return IsUnderscoredTFOp(&op) || isa<ReturnOp>(op);
});
});
}
tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp(
Operation *op, OpBuilder *builder) {
// Create a new region for the tf_executor.island body
SmallVector<Value *, 8> operands;
for (Value *operand : op->getOperands())
if (operand->getType().isa<tf_executor::ControlType>())
operands.push_back(operand);
SmallVector<Type, 8> types;
for (Type result_type : op->getResultTypes())
if (!result_type.isa<TFControlFlow::TFControlType>())
types.push_back(result_type);
types.push_back(tf_executor::ControlType::get(&getContext()));
auto island = builder->create<tf_executor::IslandOp>(
op->getLoc(), types, operands, ArrayRef<NamedAttribute>{});
island.body().push_back(new Block);
return island;
}
void ControlToExecutorDialectConversion::runOnFunction() {
if (!HasOnlyTFControlOperations(getFunction())) {
LLVM_DEBUG(llvm::dbgs() << "Function has unsupported operation, skip "
"tf_executor dialect conversion\n");
return;
}
if (getFunction().getBlocks().size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expect single block function, , skip "
"tf_executor dialect conversion\n");
return;
}
Block &body = getFunction().getBody().front();
OpBuilder builder(&body, body.begin());
// Create a new tf_executor.graph at the beginning of the function.
auto graph_op = builder.create<tf_executor::GraphOp>(
getFunction().getLoc(), getFunction().getType().getResults());
graph_op.body().push_back(new Block);
builder.setInsertionPointToEnd(&graph_op.GetBody());
llvm::StringMap<tf_executor::NextIterationSourceOp> frame_name_to_loop;
// Loop over operations in the function and move them into the graph region.
for (Operation &op : llvm::make_early_inc_range(body)) {
// Skip the just-created tf_executor.graph.
if (isa<tf_executor::GraphOp>(op)) continue;
// This is the new operation that will replace the current one in the graph.
Operation *replacement = nullptr;
if (op.isKnownTerminator()) {
// This is the return of the function, we will create a fetch in the graph
// matching the operands of the returns. The return is then updated to
// take as operands the results of the tf_executor.graph operation.
SmallVector<Value *, 8> ret_vals;
for (Value *operand : op.getOperands()) ret_vals.push_back(operand);
for (auto &graph_result : llvm::enumerate(graph_op.getResults()))
op.setOperand(graph_result.index(), graph_result.value());
builder.create<tf_executor::FetchOp>(getFunction().getLoc(), ret_vals);
continue;
}
assert(IsUnderscoredTFOp(&op) && "Expected only _tf operations");
// The operands and types arrays are used to create the tf_executor ops.
SmallVector<Value *, 8> operands;
operands.append(op.getOperands().begin(), op.getOperands().end());
SmallVector<Type, 8> types;
for (Type result_type : op.getResultTypes()) {
if (result_type.isa<TFControlFlow::TFControlType>())
types.push_back(tf_executor::ControlType::get(&getContext()));
else
types.push_back(result_type);
}
auto loc = op.getLoc();
// Match the specific operation that has a tf_executor equivalent, the
// others will be wrapped in an island.
// FIXME: StringSwitch
if (op.getName().getStringRef() == "_tf.Switch") {
replacement = builder.create<tf_executor::SwitchOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.SwitchN") {
replacement = builder.create<tf_executor::SwitchNOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.Merge") {
replacement = builder.create<tf_executor::MergeOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.NextIteration.source") {
replacement = builder.create<tf_executor::NextIterationSourceOp>(
loc, op.getResult(0)->getType(), operands);
// Record a mapping of the name to the nextiteration.source so that when
// we convert the sink we can get the token.
StringAttr frame = op.getAttrOfType<StringAttr>("name");
assert(!frame.getValue().empty());
frame_name_to_loop[frame.getValue()] =
cast<tf_executor::NextIterationSourceOp>(replacement);
// Replace the results here since the _tf source does not produce a token,
// so there isn't a mapping for the new result #1.
op.getResult(0)->replaceAllUsesWith(replacement->getResult(0));
for (int i : llvm::seq<int>(1, op.getNumResults()))
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i + 1));
replacement->setAttrs(op.getAttrList());
op.erase();
continue;
} else if (op.getName().getStringRef() == "_tf.NextIteration.sink") {
StringAttr frame = op.getAttrOfType<StringAttr>("name");
assert(!frame.getValue().empty());
tf_executor::NextIterationSourceOp srcOp =
frame_name_to_loop[frame.getValue()];
replacement = builder.create<tf_executor::NextIterationSinkOp>(
loc, srcOp.token(), operands, ArrayRef<NamedAttribute>{});
replacement->setAttrs(op.getAttrList());
op.erase();
continue;
} else if (op.getName().getStringRef() == "_tf.LoopCond") {
replacement = builder.create<tf_executor::LoopCondOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.Enter") {
replacement = builder.create<tf_executor::EnterOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.Exit") {
replacement = builder.create<tf_executor::ExitOp>(
loc, types, operands, ArrayRef<NamedAttribute>{});
} else if (op.getName().getStringRef() == "_tf.ControlTrigger") {
replacement =
builder.create<tf_executor::ControlTriggerOp>(loc, operands);
} else {
tf_executor::IslandOp island = CreateIslandForOp(&op, &builder);
replacement = island.getOperation();
// General case, drop the leading _ off the name and wrap in an island.
OperationState result(loc, op.getName().getStringRef().drop_front());
// Only the non-control operands are carried over, the island is handling
// the control input.
for (Value *operand : op.getOperands())
if (!operand->getType().isa<tf_executor::ControlType>())
result.operands.push_back(operand);
// Add a result type for each non-control result we find
bool sawControlResult = false;
for (Type result_type : op.getResultTypes()) {
if (result_type.isa<TFControlFlow::TFControlType>()) {
sawControlResult = true;
continue;
}
// We assume all control inputs are at the end of the result list.
assert(!sawControlResult && "all control results must be last");
result.types.push_back(result_type);
}
// Create the operation inside the island
OpBuilder island_builder(&island.GetBody());
Operation *inner_op = island_builder.createOperation(result);
inner_op->setAttrs(op.getAttrList());
// Add the terminator for the island
SmallVector<Value *, 8> ret_vals(inner_op->getResults());
island_builder.create<tf_executor::YieldOp>(loc, ret_vals);
}
// Copy the attributes from the original operation to the replacement and
// remap the results.
if (!isa<tf_executor::IslandOp>(replacement))
replacement->setAttrs(op.getAttrList());
for (int i : llvm::seq<int>(0, op.getNumResults()))
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i));
op.erase();
}
}
FunctionPassBase *CreateTFControlToExecutorDialectConversion() {
return new ControlToExecutorDialectConversion();
}
} // namespace mlir
static mlir::PassRegistration<mlir::ControlToExecutorDialectConversion> pass(
"tf-control-to-executor-conversion",
"Transform from TF control dialect to TF executor dialect.");

View File

@@ -88,23 +88,23 @@ class Exporter {
// one entry function, which is identified by name "main". This entry function
// is converted to the base of the graph graph. The rest of the functions are
// converted to the library functions in that graph.
-static Status Convert(mlir::Module module, const ExporterConfigs& configs,
+static Status Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
-// Converts a given Function to a FunctionDef and adds it to the function
+// Converts a given FuncOp to a FunctionDef and adds it to the function
// definition library
static Status ConvertLibFunction(const ExporterConfigs& configs,
const Dialect* tf_dialect,
-mlir::Function function,
+mlir::FuncOp function,
FunctionDefLibrary* flib);
-// Converts the given CFG Function to a Graph. The arguments and returns of
+// Converts the given FuncOp to a Graph. The arguments and returns of
// function are added to the graph with special op names kArgOp and kRetOp.
// Later on, this graph can be converted a function definition and added to
// another graph.
static StatusOr<std::unique_ptr<Graph>> Convert(
const ExporterConfigs& configs, const Dialect* tf_dialect,
-mlir::Function function, FunctionDefLibrary* flib);
+mlir::FuncOp function, FunctionDefLibrary* flib);
private:
explicit Exporter(Graph* graph, const Dialect* tf_dialect)
@@ -376,11 +376,11 @@ Status Exporter::AddNextIterationNode(mlir::Operation* inst) {
StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
const Dialect* tf_dialect,
-mlir::Function function,
+mlir::FuncOp function,
FunctionDefLibrary* flib) {
if (function.getBlocks().size() != 1) {
return errors::FailedPrecondition(
-"Input Function must have only one basic block!");
+"Input FuncOp must have only one basic block!");
}
mlir::Block& block = function.front();
@@ -433,7 +433,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
mlir::Type type = arg->getType();
if (!type.isa<mlir::TensorType>()) {
return errors::InvalidArgument(
-"Functions arguments must have tensor types. Found ",
+"FuncOps arguments must have tensor types. Found ",
mlir::debugString(type), " in function ", function.getName().str());
}
@@ -448,7 +448,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
// definition library
// TODO(prakalps): If two functions have cyclic dependence, this will
// introduce an infinite loop.
-auto func = function.getModule().getNamedFunction(op_name.ValueOrDie());
+auto func =
+function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+op_name.ValueOrDie());
if (func != nullptr) {
TF_RETURN_IF_ERROR(ConvertLibFunction(confs, tf_dialect, func, flib));
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
@@ -483,7 +485,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
const Dialect* tf_dialect,
-mlir::Function function,
+mlir::FuncOp function,
FunctionDefLibrary* flib) {
// First look for the function in the current function library. If found,
// nothing needs to be done.
@@ -510,8 +512,10 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
// Checks for gradient attribute. If present converts the gradient function
// and populates the GradientDef.
auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
-if (auto attr = function.getAttrOfType<mlir::FunctionAttr>(grad_string)) {
+if (auto attr = function.getAttrOfType<mlir::SymbolRefAttr>(grad_string)) {
-auto grad_func = function.getModule().getNamedFunction(attr.getValue());
+auto grad_func =
+function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+attr.getValue());
TF_RETURN_IF_ERROR(
ConvertLibFunction(configs, tf_dialect, grad_func, flib));
GradientDef grad;
@@ -531,15 +535,15 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
return Status::OK();
}
-Status Exporter::Convert(mlir::Module module, const ExporterConfigs& configs,
+Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
mlir::Identifier entry_func_id =
mlir::Identifier::get("main", module.getContext());
-absl::optional<mlir::Function> entry_func;
+absl::optional<mlir::FuncOp> entry_func;
FunctionDefLibrary flib;
auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
-for (auto function : module.getOps<mlir::Function>()) {
+for (auto function : module.getOps<mlir::FuncOp>()) {
if (function.isExternal())
return errors::FailedPrecondition("External functions not supported");
@@ -567,14 +571,14 @@ Status Exporter::Convert(mlir::Module module, const ExporterConfigs& configs,
}
} // namespace
-Status ConvertMlirToGraph(mlir::Module module, const ExporterConfigs& confs,
+Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs,
std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def) {
return Exporter::Convert(module, confs, graph, flib_def);
}
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-mlir::Module module, const ExporterConfigs& confs) {
+mlir::ModuleOp module, const ExporterConfigs& confs) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(),
FunctionDefLibrary());
auto graph = absl::make_unique<Graph>(flib_def);
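The export entry point now walks the module's functions as FuncOps. A small sketch of that traversal on its own, with illustrative names, mirroring the loop in Exporter::Convert above:

// Visit every function defined in the module, skipping external declarations.
static void ForEachDefinedFunction(mlir::ModuleOp module) {
  for (auto function : module.getOps<mlir::FuncOp>()) {
    if (function.isExternal()) continue;
    llvm::outs() << "exporting " << function.getName() << "\n";
  }
}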

View File

@@ -32,13 +32,13 @@ using stream_executor::port::StatusOr;
// Given an MLIR module, returns a GraphDef.
StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-mlir::Module module, const ExporterConfigs& configs);
+mlir::ModuleOp module, const ExporterConfigs& configs);
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library.
stream_executor::port::Status ConvertMlirToGraph(
-mlir::Module module, const ExporterConfigs& confs,
+mlir::ModuleOp module, const ExporterConfigs& confs,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
} // namespace tensorflow

View File

@@ -30,6 +30,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
+#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
@@ -86,7 +87,7 @@ class Importer {
explicit Importer(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-const NodeSpecs& specs, mlir::Module module,
+const NodeSpecs& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
: module_(module),
context_(module.getContext()),
@@ -158,8 +159,8 @@ class Importer {
return ::tensorflow::ConvertTensorProto(value, builder_.get());
}
-// Converts func name in graphdef to mlir::FunctionAttribute.
-StatusOr<mlir::FunctionAttr> ConvertFunctionCallName(
+// Converts func name in graphdef to mlir::SymbolRefAttribute.
+StatusOr<mlir::SymbolRefAttr> ConvertFunctionCallName(
const std::string& func_name);
// Converts the given non-function-call AttrValue to an MLIR Attribute.
@@ -262,7 +263,7 @@ class Importer {
using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;
std::unique_ptr<mlir::OpBuilder> builder_;
-mlir::Module module_;
+mlir::ModuleOp module_;
mlir::MLIRContext* context_;
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
const FunctionLibraryDefinition& graph_flib_;
@@ -611,12 +612,12 @@ Status Importer::ConvertFunctionCallAttribute(
return Status::OK();
}
-StatusOr<mlir::FunctionAttr> Importer::ConvertFunctionCallName(
+StatusOr<mlir::SymbolRefAttr> Importer::ConvertFunctionCallName(
const std::string& func_name) {
TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
-auto func = module_.getNamedFunction(mlir_func_name);
-return builder_->getFunctionAttr(func);
+auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
+return builder_->getSymbolRefAttr(func);
}
StatusOr<mlir::Attribute> Importer::ConvertAttributeValue(
@@ -721,8 +722,8 @@ Status Importer::ConvertLibFunction(const std::string& func_name) {
if (!grad_func_name.empty()) {
TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
-auto grad_func = module_.getNamedFunction(mlir_grad_func_name);
-auto gradient_attr = builder_->getFunctionAttr(grad_func);
+auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
+auto gradient_attr = builder_->getSymbolRefAttr(grad_func);
auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr));
}
@@ -1151,7 +1152,7 @@ Status Importer::Convert(llvm::StringRef func_name,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// TODO(b/122040776): Uses debug info for FunctionDef.
-auto function = mlir::Function::create(mlir::UnknownLoc::get(context_),
+auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs);
module_.push_back(function);
@@ -1291,7 +1292,8 @@ StatusOr<mlir::OwningModuleRef> Importer::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const NodeSpecs& specs) {
-mlir::OwningModuleRef module = mlir::Module::create(context);
+mlir::OwningModuleRef module =
+mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
Importer importer(flib_def, debug_info, specs, module.get(),
&tf_name_to_mlir_name);

View File

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
+#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
+#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
@@ -119,7 +120,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
break;
default:
inst.emitWarning()
-<< "Skipping splat converstion for "
+<< "Skipping splat conversion for "
<< "an unsupported attribute type " << element_type;
continue;
}

View File

@@ -63,7 +63,7 @@ static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
"graphdef-to-splatted-mlir", GraphdefToSplattedMlirTranslateFunction);
static LogicalResult MlirToGraphdefTranslateFunction(
-Module module, llvm::StringRef output_filename) {
+ModuleOp module, llvm::StringRef output_filename) {
if (!module) return failure();
std::error_code error;

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
@@ -22,8 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
namespace mlir {
-static mlir::Operation* ExtractOnlyOp(mlir::Module module) {
-mlir::Function fn = module.getNamedFunction("main");
+static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
+mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main");
if (!fn) return nullptr;
if (fn.getBlocks().size() != 1) return nullptr;
@@ -38,7 +39,8 @@ static mlir::Operation* ExtractOnlyOp(mlir::Module module) {
return &block.front();
}
-static LogicalResult MlirToTfNodeDef(Module module, llvm::StringRef filename) {
+static LogicalResult MlirToTfNodeDef(ModuleOp module,
+llvm::StringRef filename) {
auto* context = module.getContext();
auto file = openOutputFile(filename);

View File

@@ -46,19 +46,15 @@ Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) {
*type = builder.getIntegerType(1);
return Status::OK();
case DT_INT8:
-case DT_UINT8:
*type = builder.getIntegerType(8);
return Status::OK();
case DT_INT16:
-case DT_UINT16:
*type = builder.getIntegerType(16);
return Status::OK();
case DT_INT32:
-case DT_UINT32:
*type = builder.getIntegerType(32);
return Status::OK();
case DT_INT64:
-case DT_UINT64:
*type = builder.getIntegerType(64);
return Status::OK();
case DT_BFLOAT16:

View File

@@ -115,7 +115,7 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
return Status::OK();
}
-Status ConvertAttribute(const mlir::FunctionAttr& attr, AttrValue* value) {
+Status ConvertAttribute(const mlir::SymbolRefAttr& attr, AttrValue* value) {
value->mutable_func()->set_name(attr.getValue());
return Status::OK();
}
@@ -149,7 +149,7 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
TensorProto tensor;
TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
*list->add_tensor() = tensor;
-} else if (auto attr = a.dyn_cast<mlir::FunctionAttr>()) {
+} else if (auto attr = a.dyn_cast<mlir::SymbolRefAttr>()) {
AttrValue attrVal;
TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attrVal));
*list->add_func() = attrVal.func();
@@ -218,8 +218,8 @@ Status ConvertAttributes(const llvm::ArrayRef<mlir::NamedAttribute> attrs,
}
AttrValue value;
switch (attr.getKind()) {
-case mlir::StandardAttributes::Function: {
-auto func_attr = attr.cast<mlir::FunctionAttr>();
+case mlir::StandardAttributes::SymbolRef: {
+auto func_attr = attr.cast<mlir::SymbolRefAttr>();
value.mutable_func()->set_name(func_attr.getValue());
func_call_attrs[string(name)] = value;
continue;

View File

@ -298,8 +298,8 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
let arguments = (ins let arguments = (ins
Variadic<XLA_TensorOrTuple>:$val, Variadic<XLA_TensorOrTuple>:$val,
FunctionAttr:$cond, SymbolRefAttr:$cond,
FunctionAttr:$body SymbolRefAttr:$body
); );
let results = (outs Variadic<XLA_TensorOrTuple>:$res); let results = (outs Variadic<XLA_TensorOrTuple>:$res);
@ -320,7 +320,7 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> {
let arguments = (ins let arguments = (ins
Variadic<XLA_TensorOrTuple>:$operands_and_init, Variadic<XLA_TensorOrTuple>:$operands_and_init,
FunctionAttr:$computation, SymbolRefAttr:$computation,
ElementsAttr:$dimensions ElementsAttr:$dimensions
); );

View File

@ -116,12 +116,14 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype): def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype) return random_ops.truncated_normal(shape=[2], dtype=dtype)
self._testRngIsNotConstant(rng, dtypes.float32) # TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {np.float32, np.float64}:
self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self): def testTruncatedNormalIsInRange(self):
count = 10000000 count = 10000000
# TODO(b/34339814): make this test work with 16 bit float types. # TODO(b/34339814): make this test work with 16 bit float types.
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: for dtype in self._random_types() & {np.float32, np.float64}:
with self.session() as sess: with self.session() as sess:
with self.test_scope(): with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype) x = random_ops.truncated_normal(shape=[count], dtype=dtype)

View File

@ -293,7 +293,7 @@ class TruncatedNormalOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("TruncatedNormal") REGISTER_XLA_OP(Name("TruncatedNormal")
.CompileTimeConstantInput("shape") .CompileTimeConstantInput("shape")
.TypeConstraint("dtype", DT_FLOAT), .TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
TruncatedNormalOp); TruncatedNormalOp);
} // namespace } // namespace

View File

@ -42,63 +42,199 @@ const char kRetvalOp[] = "_Retval";
const int kMaxCallDepth = 100; const int kMaxCallDepth = 100;
Status AnalyzeResourceUsage(
const Graph* graph, const absl::optional<std::string>& function_name,
const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
FunctionLibraryRuntime* lib_runtime,
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
source_to_path);
bool IsControlFlowV1Node(const Node* n) { bool IsControlFlowV1Node(const Node* n) {
return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() || return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
n->IsNextIteration()); n->IsNextIteration());
} }
// Given an output edge, find the corresponding input edge if the given edge
// is coming from a pass-through node. Otherwise, return nullptr.
StatusOr<const Edge*> WalkBackPassThroughEdge(const Edge* e) {
const Node* n = e->src();
if (n->IsIdentity()) {
const Edge* ret;
TF_RETURN_IF_ERROR(n->input_edge(0, &ret));
return ret;
}
if (n->type_string() == kIdentityNOp) {
const Edge* ret;
TF_RETURN_IF_ERROR(n->input_edge(e->src_output(), &ret));
return ret;
}
// Reaching here means e is not coming from a pass-through node; return
// nullptr to indicate we can no longer trace back.
return nullptr;
}
// TODO(ycao): Add this as Tensorflow Node method. // TODO(ycao): Add this as Tensorflow Node method.
StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node* n, StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node& n,
int idx) { int idx) {
absl::InlinedVector<const Edge*, 1> res; absl::InlinedVector<const Edge*, 1> res;
if (idx >= n->num_outputs()) { if (idx >= n.num_outputs()) {
return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ", return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
n->name(), " only has ", n->num_outputs(), n.name(), " only has ", n.num_outputs(),
" outputs."); " outputs.");
} }
for (const Edge* o : n->out_edges()) { for (const Edge* o : n.out_edges()) {
if (o->src_output() == idx) res.emplace_back(o); if (o->src_output() == idx) res.emplace_back(o);
} }
return res; return res;
} }
bool IsStackOrTensorArraySource(const Node* n) { bool IsStackOrTensorArraySource(const Node& n) {
const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string()); const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
if (!op_info) return false; if (!op_info) return false;
if (op_info->resource_kind() != XlaResourceKind::kStack && if (op_info->resource_kind() != XlaResourceKind::kStack &&
op_info->resource_kind() != XlaResourceKind::kTensorArray) op_info->resource_kind() != XlaResourceKind::kTensorArray)
return false; return false;
return n->num_outputs() > 0 && n->output_type(0) == DataType::DT_RESOURCE; return n.num_outputs() > 0 && n.output_type(0) == DataType::DT_RESOURCE;
}
void PropagateFromStackOrTensorArraySourceOp(
const Node& n, const absl::optional<std::string>& function_name,
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
user_to_source) {
ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
n.type_string());
for (const Edge* o : n.out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
}
(*user_to_source)[o] = src_node_info;
}
}
Status PropagateFromArgOp(
const Node& n, const absl::optional<std::string>& function_name,
const absl::flat_hash_set<int>& resource_arg_indices,
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
user_to_source) {
TF_RET_CHECK(n.type_string() == kArgOp);
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index));
if (!resource_arg_indices.contains(index)) return Status::OK();
TF_RET_CHECK(function_name.has_value())
<< "ResourceUsageAnalysis does not support analyzing _Arg nodes "
"carrying Stack/TensorArray resource in given graph unless they "
"are in function calls.";
const ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
n.type_string());
for (const Edge* o : n.out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
}
(*user_to_source)[o] = src_node_info;
}
return Status::OK();
}
Status PropagateThroughCallOp(
const Node& n, const absl::optional<std::string>& function_name,
const int call_depth, FunctionLibraryRuntime* lib_runtime,
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
user_to_source,
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
source_to_path) {
if (call_depth > kMaxCallDepth) {
return errors::InvalidArgument(
"Function call stack in given graph is too deep, last function ",
"name is: ", function_name.value());
}
// resource_arg_indices_for_call contains all indices of the input
// arguments that carry Stack/TensorArray resource handles.
absl::flat_hash_set<int> resource_arg_indices_for_call;
for (const Edge* e : n.in_edges()) {
if (!user_to_source->contains(e)) continue;
resource_arg_indices_for_call.emplace(e->dst_input());
}
absl::string_view called_function_name = n.type_string();
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(InstantiateFunctionCall(n.def(), lib_runtime, &handle));
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
// Recursively analyze called function for resource sources and users.
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
called_function_source_to_path;
TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
fbody->graph, absl::optional<std::string>(called_function_name),
call_depth + 1, resource_arg_indices_for_call, lib_runtime,
&called_function_source_to_path));
std::unordered_map<std::string, Node*> node_name_index =
fbody->graph->BuildNodeNameIndex();
for (auto it : called_function_source_to_path) {
ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
// If the source is an _Arg, then the true source is actually the corresponding
// edge that feeds into the function call node with the same index.
if (src_node_info.op_ == kArgOp) {
const Node* arg_src = node_name_index[src_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
const Edge* e;
TF_RETURN_IF_ERROR(n.input_edge(index, &e));
const Node* true_src = e->src();
src_node_info.function_name_ = function_name;
src_node_info.node_name_ = true_src->name();
src_node_info.op_ = true_src->type_string();
}
for (const auto& dst_node_info : it.second) {
// If the user is a _Retval, then the true user is actually the corresponding
// edge of that _Retval.
if (dst_node_info.op_ == kRetvalOp) {
const Node* ret_user = node_name_index[dst_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
absl::InlinedVector<const Edge*, 1> outs;
TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
for (const Edge* o : outs) (*user_to_source)[o] = src_node_info;
} else {
(*source_to_path)[src_node_info].emplace(dst_node_info);
}
}
}
return Status::OK();
}
// Analyzes pass through values for Identity and IdentityN ops.
Status PropagateThroughIdentityOp(
const Node& n,
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
user_to_source) {
TF_RET_CHECK(n.IsIdentity() || n.type_string() == kIdentityNOp);
if (n.IsIdentity()) {
for (const Edge* o : n.out_edges()) {
if (o->IsControlEdge()) continue;
const Edge* in;
TF_RETURN_IF_ERROR(n.input_edge(0, &in));
if (!user_to_source->contains(in)) continue;
user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
}
} else {
for (const Edge* o : n.out_edges()) {
if (o->IsControlEdge()) continue;
const Edge* in;
TF_RETURN_IF_ERROR(n.input_edge(o->src_output(), &in));
if (!user_to_source->contains(in)) continue;
user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
}
}
return Status::OK();
} }
Status AnalyzeResourceUsage( Status AnalyzeResourceUsage(
const Graph* graph, FunctionLibraryRuntime* lib_runtime, const Graph* graph, const absl::optional<std::string>& function_name,
const absl::optional<std::string>& function_name, const int call_depth, const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
const absl::flat_hash_set<int>& resource_arg_indices, FunctionLibraryRuntime* lib_runtime,
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo, absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>* absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
source_to_path) { source_to_path) {
@ -127,120 +263,30 @@ Status AnalyzeResourceUsage(
} }
// Record a resource source edge. // Record a resource source edge.
if (IsStackOrTensorArraySource(n)) { if (IsStackOrTensorArraySource(*n)) {
ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n->name(), PropagateFromStackOrTensorArraySourceOp(*n, function_name,
n->type_string()); &user_to_source);
for (const Edge* o : n->out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
}
user_to_source[o] = src_node_info;
}
continue; continue;
} }
// Arguments that are listed in resource_arg_indices are also considered as // Arguments that are listed in resource_arg_indices are also considered as
// resource sources. // resource sources.
if (n->IsArg()) { if (n->IsArg()) {
int index; TF_RETURN_IF_ERROR(PropagateFromArgOp(
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); *n, function_name, resource_arg_indices, &user_to_source));
if (!resource_arg_indices.contains(index)) continue;
TF_RET_CHECK(function_name.has_value())
<< "ResourceUsageAnalysis does not support analyzing _Arg nodes "
"carrying Stack/TensorArray resource in given graph unless they "
"are in function calls.";
const ResourceUsageAnalysis::NodeInfo src_node_info(
function_name, n->name(), n->type_string());
for (const Edge* o : n->out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
}
user_to_source[o] = src_node_info;
}
continue; continue;
} }
// Recursively analyze function call ops.
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) { if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
if (call_depth > kMaxCallDepth) { TF_RETURN_IF_ERROR(PropagateThroughCallOp(*n, function_name, call_depth,
return errors::InvalidArgument( lib_runtime, &user_to_source,
"Function call stack in given graph is too deep, last function ", source_to_path));
"name is: ", function_name.value());
}
// resource_arg_indices_for_call contains all indices of the input
// arguments that carry Stack/TensorArray resource handles.
absl::flat_hash_set<int> resource_arg_indices_for_call;
for (const Edge* e : n->in_edges()) {
if (!user_to_source.contains(e)) continue;
resource_arg_indices_for_call.emplace(e->dst_input());
}
absl::string_view called_function_name = n->type_string();
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
InstantiateFunctionCall(n->def(), lib_runtime, &handle));
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
// Recursively analyze called function for resource sources and users.
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
called_function_source_to_path;
TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
fbody->graph, lib_runtime,
absl::optional<std::string>(called_function_name), call_depth + 1,
resource_arg_indices_for_call, &called_function_source_to_path));
std::unordered_map<std::string, Node*> node_name_index =
fbody->graph->BuildNodeNameIndex();
for (auto it : called_function_source_to_path) {
ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
// If source is an _Arg, then the true source is actually corresponding
// edge that feeds into function call node with the same index.
if (src_node_info.op_ == kArgOp) {
const Node* arg_src = node_name_index[src_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
const Edge* e;
TF_RETURN_IF_ERROR(n->input_edge(index, &e));
const Node* true_src = e->src();
src_node_info.function_name_ = function_name;
src_node_info.node_name_ = true_src->name();
src_node_info.op_ = true_src->type_string();
}
for (const auto& dst_node_info : it.second) {
// If user is an _Retval, then the true user is actually corresponding
// edge of that _Retval.
if (dst_node_info.op_ == kRetvalOp) {
const Node* ret_user = node_name_index[dst_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
absl::InlinedVector<const Edge*, 1> outs;
TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
for (const Edge* o : outs) user_to_source[o] = src_node_info;
} else {
(*source_to_path)[src_node_info].emplace(dst_node_info);
}
}
}
continue; continue;
} }
for (const Edge* o : n->out_edges()) { if (n->IsIdentity() || n->type_string() == kIdentityNOp) {
if (o->IsControlEdge()) continue; TF_RETURN_IF_ERROR(PropagateThroughIdentityOp(*n, &user_to_source));
TF_ASSIGN_OR_RETURN(const Edge* e, WalkBackPassThroughEdge(o));
if (!e || !user_to_source.contains(e)) continue;
user_to_source.emplace(std::make_pair(o, user_to_source[e]));
} }
} }
@ -260,8 +306,9 @@ Status AnalyzeResourceUsage(
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>* absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
source_to_path) { source_to_path) {
return AnalyzeResourceUsage( return AnalyzeResourceUsage(
graph, lib_runtime, /*function_name=*/{}, /*call_depth=*/0, graph, /*function_name=*/{}, /*call_depth=*/0,
/*resource_arg_indices=*/absl::flat_hash_set<int>(), source_to_path); /*resource_arg_indices=*/absl::flat_hash_set<int>(), lib_runtime,
source_to_path);
} }
} // namespace tensorflow } // namespace tensorflow
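The refactored analysis above is, at its core, a label-propagation pass: edges leaving a Stack/TensorArray source get labelled with that source, pass-through ops (Identity/IdentityN and function-call boundaries) copy labels from their input edges to their output edges, and labelled edges arriving at users end up in `source_to_path`. The self-contained sketch below shows only that propagation pattern over a toy edge list; the `Edge`, `sources`, and `pass_through` names are illustrative stand-ins, not the TensorFlow `Graph` API:

```c++
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Edge { std::string src, dst; };

int main() {
  // Toy graph, already in topological order: a stack source feeds an Identity,
  // which feeds a stack push. The Identity is a pass-through, so both of its
  // edges should be attributed to the "stack" source.
  std::vector<Edge> edges = {{"stack", "identity"}, {"identity", "push"}};
  std::set<std::string> sources = {"stack"};
  std::set<std::string> pass_through = {"identity"};

  std::map<const Edge*, std::string> edge_to_source;             // ~ user_to_source
  std::map<std::string, std::set<std::string>> source_to_users;  // ~ source_to_path

  for (const Edge& e : edges) {
    if (sources.count(e.src)) {
      edge_to_source[&e] = e.src;        // label edges leaving a source node
    } else if (pass_through.count(e.src)) {
      for (const Edge& in : edges) {     // copy the label across the node
        if (in.dst == e.src && edge_to_source.count(&in)) {
          edge_to_source[&e] = edge_to_source[&in];
        }
      }
    }
    if (edge_to_source.count(&e)) {
      source_to_users[edge_to_source[&e]].insert(e.dst);
    }
  }
  for (const auto& kv : source_to_users) {
    for (const std::string& user : kv.second) {
      // Prints: "stack is used by identity" and "stack is used by push".
      std::cout << kv.first << " is used by " << user << "\n";
    }
  }
}
```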

View File

@ -1367,12 +1367,12 @@ For a more intuitive description, see the "Informal Description" section below.
: : : detailed description. : : : : detailed description. :
| `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the | | `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the |
: : : output shape that offset into : : : : output shape that offset into :
: : : a array sliced from operand. : : : : an array sliced from operand. :
| `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the | | `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the |
: : : bounds for the slice on : : : : bounds for the slice on :
: : : dimension `i`. : : : : dimension `i`. :
| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each | | `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
: : : \: slice that are collapsed : : : : slice that are collapsed :
: : : away. These dimensions must : : : : away. These dimensions must :
: : : have size 1. : : : : have size 1. :
| `start_index_map` | `ArraySlice<int64>` | A map that describes how to | | `start_index_map` | `ArraySlice<int64>` | A map that describes how to |
@ -1383,8 +1383,11 @@ For a more intuitive description, see the "Informal Description" section below.
For convenience, we label dimensions in the output array not in `offset_dims` For convenience, we label dimensions in the output array not in `offset_dims`
as `batch_dims`. as `batch_dims`.
The output is an array of rank `batch_dims.size` + `operand.rank` - The output is an array of rank `batch_dims.size` + `offset_dims.size`.
`collapsed_slice_dims`.size.
The `operand.rank` must equal the sum of `offset_dims.size` and
`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to
`operand.rank`.
If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of `start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
@ -1405,35 +1408,40 @@ accounting for `collapsed_slice_dims` (i.e. we pick
`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
with the bounds at indices `collapsed_slice_dims` removed). with the bounds at indices `collapsed_slice_dims` removed).
Formally, the operand index `In` corresponding to an output index `Out` is Formally, the operand index `In` corresponding to a given output index `Out` is
computed as follows: calculated as follows:
1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
Combine(A, b) inserts b at position `index_vector_dim` into A. Note that Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
this is well defined even if `G` is empty -- if `G` is empty then `S` = this is well defined even if `G` is empty -- if `G` is empty then `S` =
`start_indices`. `start_indices`.
2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
scattering `S` using `start_index_map`. More precisely: scattering `S` using `start_index_map`. More precisely:
1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
`start_index_map.size`. 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. `start_index_map.size`.
2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
at the offset dimensions in `Out` according to the `collapsed_slice_dims` at the offset dimensions in `Out` according to the `collapsed_slice_dims`
set. More precisely: set. More precisely:
1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
(`expand_offset_dims` is defined below).
2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
addition.
`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) 1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
(`remapped_offset_dims` is defined below).
2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
addition.
`remapped_offset_dims` is a monotonic function with domain [`0`, `offset.size`)
and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`, `offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. `2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
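As a worked illustration of the last paragraph, the sketch below (not part of XLA; every name in it is local to the example) computes `remapped_offset_dims` for `operand.rank` = 6 and `collapsed_slice_dims` = {0, 2} and prints the mapping {0→1, 1→3, 2→4, 3→5} quoted above:

```c++
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Builds the monotonic map from [0, offset_dims.size) onto the operand
// dimensions that are not collapsed, i.e. remapped_offset_dims.
std::vector<int64_t> RemappedOffsetDims(int64_t operand_rank,
                                        const std::set<int64_t>& collapsed) {
  std::vector<int64_t> remapped;
  for (int64_t d = 0; d < operand_rank; ++d) {
    if (collapsed.count(d) == 0) remapped.push_back(d);
  }
  return remapped;
}

int main() {
  // operand.rank = 6 and collapsed_slice_dims = {0, 2}, as in the text above.
  std::vector<int64_t> remapped = RemappedOffsetDims(6, {0, 2});
  for (size_t k = 0; k < remapped.size(); ++k) {
    std::cout << k << "->" << remapped[k] << " ";  // prints 0->1 1->3 2->4 3->5
  }
  std::cout << "\n";
}
```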
### Informal Description and Examples ### Informal Description and Examples
@ -1441,25 +1449,24 @@ Informally, every index `Out` in the output array corresponds to an element `E`
in the operand array, computed as follows: in the operand array, computed as follows:
- We use the batch dimensions in `Out` to look up a starting index from - We use the batch dimensions in `Out` to look up a starting index from
`start_indices`. `start_indices`.
- We use `start_index_map` to map the starting index (which may have size less - We use `start_index_map` to map the starting index (whose size may be less
than operand.rank) to a "full" starting index into operand. than operand.rank) to a "full" starting index into the `operand`.
- We dynamic-slice out a slice with size `slice_sizes` using the full starting - We dynamic-slice out a slice with size `slice_sizes` using the full starting
index. index.
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
Since all collapsed slice dimensions have to have bound 1 this reshape is Since all collapsed slice dimensions must have a bound of 1, this reshape is
always legal. always legal.
- We use the offset dimensions in `Out` to index into this slice to get the - We use the offset dimensions in `Out` to index into this slice to get the
input element, `E`, corresponding to output index `Out`. input element, `E`, corresponding to output index `Out`.
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
examples that follow. More interesting values for `index_vector_dim` does not that follow. More interesting values for `index_vector_dim` do not change the
change the operation fundamentally, but makes the visual representation more operation fundamentally, but make the visual representation more cumbersome.
cumbersome.
To get an intuition on how all of the above fits together, let's look at an To get an intuition on how all of the above fits together, let's look at an
example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
@ -1529,9 +1536,9 @@ from the gather indices array as usual, except the starting index has only one
element, `X`. Similarly, there is only one output offset index with the value element, `X`. Similarly, there is only one output offset index with the value
`O`<sub>`0`</sub>. However, before being used as indices into the input array, `O`<sub>`0`</sub>. However, before being used as indices into the input array,
these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively, adding up formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
to [`X`,`O`<sub>`0`</sub>]. In other words, the output index adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index [`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us [`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
the semantics for `tf.gather_nd`. the semantics for `tf.gather_nd`.

View File

@ -104,11 +104,49 @@ cc_library(
], ],
) )
cc_library(
name = "event_pool",
srcs = ["event_pool.cc"],
hdrs = ["event_pool.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "semaphore",
srcs = ["semaphore.cc"],
hdrs = ["semaphore.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "semaphore_test",
srcs = ["semaphore_test.cc"],
deps = [
":semaphore",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/synchronization",
],
)
cc_library( cc_library(
name = "shared_device_buffer", name = "shared_device_buffer",
srcs = ["shared_device_buffer.cc"], srcs = ["shared_device_buffer.cc"],
hdrs = ["shared_device_buffer.h"], hdrs = ["shared_device_buffer.h"],
deps = [ deps = [
":event_pool",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager",
@ -132,6 +170,23 @@ tf_cc_test(
], ],
) )
cc_library(
name = "device",
srcs = ["device.cc"],
hdrs = ["device.h"],
deps = [
":event_pool",
":semaphore",
":worker_thread",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
],
)
cc_library( cc_library(
name = "local_client", name = "local_client",
srcs = [ srcs = [
@ -149,9 +204,9 @@ cc_library(
], ],
features = ["-use_header_modules"], features = ["-use_header_modules"],
deps = [ deps = [
":device",
":shared_device_buffer", ":shared_device_buffer",
":types", ":types",
":worker_thread",
"//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
@ -165,7 +220,6 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:bfc_allocator", "//tensorflow/core:bfc_allocator",
@ -199,10 +253,11 @@ tf_pybind_extension(
":local_client", ":local_client",
":shared_device_buffer", ":shared_device_buffer",
":types", ":types",
":worker_thread",
":xrt", ":xrt",
"@com_google_absl//absl/base",
"@com_google_absl//absl/hash", "@com_google_absl//absl/hash",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
@ -225,6 +280,7 @@ tf_pybind_extension(
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:custom_call_target_registry",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_graph_dumper",

View File

@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/device.h"
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
Device::Device(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse)
: synchronous_deallocation_(synchronous_deallocation),
event_pool_(allow_event_reuse),
compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
compute_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
callback_stream_ = absl::make_unique<se::Stream>(executor);
compute_stream_->Init();
host_to_device_stream_->Init();
device_to_host_stream_->Init();
callback_stream_->Init();
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = absl::make_unique<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
execute_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
callback_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_callback");
}
Device::~Device() {
Status status = SynchronizeAllActivity();
if (!status.ok()) {
LOG(ERROR) << "Error when closing device: " << status;
}
}
Status Device::SynchronizeAllActivity() {
Status status;
// TODO(phawkins): in theory the call to SynchronizeAllActivity below should
// suffice. However on the Host platform SynchronizeAllActivity is a dummy
// implementation that doesn't actually block. To make sure activity has
// stopped, also block on the compute stream. If SynchronizeAllActivity is
// fixed, we could remove the BlockHostUntilDone call.
status.Update(compute_stream_->BlockHostUntilDone());
bool ok = compute_stream_->parent()->SynchronizeAllActivity();
if (!ok) {
status.Update(Unknown("SynchronizeAllActivity failed."));
}
return status;
}
Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer) {
// The default implementation simply calls ThenMemcpyD2D, and assumes that
// the buffer addresses identify the devices. This does not work
// on all platforms; this method is virtual so it can be overridden.
src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
return Status::OK();
}
void Device::ThenExecuteOnCallbackThread(se::Stream* stream,
std::function<void()> callback) const {
stream->ThenDoHostCallback([this, callback]() mutable {
callback_thread_->Schedule(std::move(callback));
});
}
se::Stream* Device::GetDeviceToDeviceStream() {
absl::MutexLock lock(&mu_);
int i = next_device_to_device_stream_;
next_device_to_device_stream_ =
(next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
return device_to_device_streams_.at(i).get();
}
} // namespace xla
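`GetDeviceToDeviceStream` above rotates through the pool of device-to-device streams under a mutex. A stripped-down sketch of the same round-robin selection, with plain integers standing in for `se::Stream` (toy names, not the XLA `Device` class):

```c++
#include <iostream>
#include <mutex>
#include <vector>

// Hands out indices 0..n-1 in rotation under a mutex, the same bookkeeping
// Device::GetDeviceToDeviceStream uses for its device_to_device_streams_.
class RoundRobin {
 public:
  explicit RoundRobin(int n) : items_(n) {
    for (int i = 0; i < n; ++i) items_[i] = i;
  }
  int Next() {
    std::lock_guard<std::mutex> lock(mu_);
    int picked = items_[next_];
    next_ = (next_ + 1) % static_cast<int>(items_.size());
    return picked;
  }

 private:
  std::mutex mu_;
  std::vector<int> items_;
  int next_ = 0;
};

int main() {
  RoundRobin streams(4);  // kNumDeviceToDeviceStreams is 4 in device.h above
  for (int i = 0; i < 6; ++i) std::cout << streams.Next() << " ";  // 0 1 2 3 0 1
  std::cout << "\n";
}
```

Rotating across several streams lets independent device-to-device copies overlap instead of serializing behind a single stream.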

View File

@ -0,0 +1,134 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_
#include <memory>
#include <vector>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/python/worker_thread.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace xla {
// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers.
class Device {
public:
// If synchronous_deallocation is true, the host must not free buffers until
// compute/transfers that use those buffers have completed. For example, this
// typically is the case for the CPU platform, where compute/transfers are
// operations that take place on another thread.
//
// If asynchronous is false, the host will synchronize to the device after
// each execution or transfer. This is intended for debugging only.
Device(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse);
virtual ~Device();
bool synchronous_deallocation() const { return synchronous_deallocation_; }
EventPool& event_pool() { return event_pool_; }
se::Stream* compute_stream() const { return compute_stream_.get(); }
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}
se::Stream* device_to_host_stream() const {
return device_to_host_stream_.get();
}
// Returns a device to device stream. Allocates streams in a round-robin
// fashion amongst the available streams.
se::Stream* GetDeviceToDeviceStream();
// Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer);
WorkerThread* execute_thread() const { return execute_thread_.get(); }
// Enqueues a host callback on 'stream', to be executed by callback_thread_.
// ThenDoHostCallback is often constrained in what it can do, in particular,
// on GPU the callback runs on a thread belonging to the GPU runtime and
// cannot perform GPU operations itself.
void ThenExecuteOnCallbackThread(se::Stream* stream,
std::function<void()> callback) const;
// Helper for releasing values at the tail of a stream on a worker thread.
// Copies `object`, and destroys the copy when the tail of the stream is
// reached. The destruction happens either in the caller's thread or on the
// worker thread (depending on thread schedules), not in a device callback,
// so it is safe if the destructor frees device resources (e.g., GPU objects).
// TODO(phawkins): use move-capture when we can use C++14 features.
template <typename T>
void ThenRelease(se::Stream* stream, T object) const {
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnCallbackThread(callback_stream_.get(),
[object]() { /* releases object */ });
}
Semaphore& compute_semaphore() { return compute_semaphore_; }
private:
Status SynchronizeAllActivity();
bool synchronous_deallocation_;
EventPool event_pool_;
// Semaphore used to limit how many programs can be enqueued on the compute
// stream by the host ahead of the device.
Semaphore compute_semaphore_;
std::unique_ptr<se::Stream> compute_stream_;
std::unique_ptr<se::Stream> host_to_device_stream_;
std::unique_ptr<se::Stream> device_to_host_stream_;
std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;
// Number of device-to-device streams to create in the multistream case.
static constexpr int kNumDeviceToDeviceStreams = 4;
absl::Mutex mu_;
int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
// Callback stream is used for running short host-side callbacks after device
// side events, without preventing the device-side stream from doing useful
// work.
std::unique_ptr<se::Stream> callback_stream_;
// A worker thread, used for replicated computation launches.
std::unique_ptr<WorkerThread> execute_thread_;
// A worker thread, used for callbacks. It is necessary that this be a
// different thread to the execute thread because we acquire the compute
// semaphore during calls to Execute but release it from a callback and if
// they are the same thread we might deadlock.
std::unique_ptr<WorkerThread> callback_thread_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_
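The `ThenRelease` helper above keeps a value alive by copy-capturing it into a callback whose destruction is deferred until the stream's tail is reached. A minimal sketch of that keep-alive-by-capture idiom, using a plain `std::function` queue in place of a stream and worker thread (illustrative only; nothing beyond the standard library is assumed):

```c++
#include <functional>
#include <iostream>
#include <memory>
#include <queue>

int main() {
  // Stand-in for the callback stream/worker thread: callbacks run later.
  std::queue<std::function<void()>> callback_queue;

  {
    auto buffer = std::make_shared<int>(42);  // stands in for a device buffer
    std::cout << "references before enqueue: " << buffer.use_count() << "\n";  // 1
    // Copy-capture keeps the buffer alive until the queued callback is
    // destroyed, mirroring ThenRelease's "[object]() { /* releases object */ }".
    callback_queue.push([buffer] { /* releasing the captured copy is the point */ });
    std::cout << "references after enqueue: " << buffer.use_count() << "\n";  // 2
  }  // the caller's reference is gone; only the queued copy remains

  // Draining the queue (reaching the "tail of the stream") drops the last
  // reference, destroying the buffer on this thread rather than in a device
  // callback.
  while (!callback_queue.empty()) callback_queue.pop();
  std::cout << "queue drained; buffer released\n";
}
```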

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
EventPool::Handle::~Handle() {
if (pool_ && event_) {
absl::MutexLock lock(&pool_->mu_);
pool_->free_events_.push(std::move(event_));
}
}
EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse) {}
StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
se::Stream* stream) {
Handle event;
if (allow_reuse_) {
event.pool_ = this;
absl::MutexLock lock(&mu_);
if (!free_events_.empty()) {
event.event_ = std::move(free_events_.top());
free_events_.pop();
}
}
if (!event.event_) {
event.event_ = absl::make_unique<se::Event>(stream->parent());
TF_RET_CHECK(event.event_->Init()) << "Event initialization failed";
}
stream->ThenRecordEvent(event.event_.get());
return event;
}
} // namespace xla

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#include <memory>
#include <stack>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace xla {
class EventPool {
public:
class Handle {
public:
Handle() = default;
~Handle();
Handle(const Handle&) = delete;
Handle(Handle&&) = default;
Handle& operator=(const Handle&) = delete;
Handle& operator=(Handle&&) = default;
se::Event* event() const { return event_.get(); }
private:
friend class EventPool;
EventPool* pool_ = nullptr;
std::unique_ptr<se::Event> event_;
};
// Initializes a new EventPool. If `allow_reuse` is true, then events will be
// returned to the pool when their handles are deleted and made available to
// subsequent allocations. Reuse only works on the GPU platform.
explicit EventPool(bool allow_reuse);
// Allocates a new (or reused) event from the pool, and records the event on
// `stream`.
//
// Reuse is only possible on GPU. Event allocation and recording are coupled
// in a single operation because on GPU it is recording an event that makes it
// a "new" event. According to the CUDA documentation it is safe to call
// cudaEventRecord even if that event may still be in use on the device; APIs
// such as cudaStreamWaitEvent capture the state of the event at the time of
// the host-side call and are not affected by a later host-side
// cudaEventRecord.
StatusOr<Handle> ThenAllocateAndRecordEvent(se::Stream* stream);
private:
const bool allow_reuse_;
absl::Mutex mu_;
std::stack<std::unique_ptr<se::Event>> free_events_ GUARDED_BY(mu_);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
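`EventPool::Handle` returns its event to the pool's free list when the handle is destroyed (see the destructor in `event_pool.cc` above). Below is a compact sketch of that return-to-pool RAII pattern with a generic heap object standing in for `se::Event`; the `Pool` and `Handle` names here are hypothetical and only mirror the structure, not the XLA API:

```c++
#include <iostream>
#include <memory>
#include <mutex>
#include <stack>
#include <string>

// A pool whose Handle hands the resource back on destruction, mirroring
// EventPool::Handle::~Handle(); std::string stands in for se::Event.
class Pool {
 public:
  class Handle {
   public:
    Handle(Pool* pool, std::unique_ptr<std::string> res)
        : pool_(pool), res_(std::move(res)) {}
    ~Handle() {
      if (pool_ && res_) {
        std::lock_guard<std::mutex> lock(pool_->mu_);
        pool_->free_.push(std::move(res_));  // return the resource to the pool
      }
    }
    Handle(Handle&&) = default;
    Handle& operator=(Handle&&) = default;
    const std::string& value() const { return *res_; }

   private:
    Pool* pool_;
    std::unique_ptr<std::string> res_;
  };

  // Reuses a pooled resource if one is free, otherwise allocates a new one.
  Handle Allocate() {
    std::lock_guard<std::mutex> lock(mu_);
    if (!free_.empty()) {
      std::unique_ptr<std::string> res = std::move(free_.top());
      free_.pop();
      return Handle(this, std::move(res));
    }
    return Handle(this, std::make_unique<std::string>("event"));
  }

 private:
  std::mutex mu_;
  std::stack<std::unique_ptr<std::string>> free_;
};

int main() {
  Pool pool;
  {
    Pool::Handle h = pool.Allocate();  // allocates a fresh resource
    std::cout << "using " << h.value() << "\n";
  }  // handle destroyed here; the resource goes back onto the free stack
  Pool::Handle again = pool.Allocate();  // pops the pooled resource instead
  std::cout << "reused " << again.value() << "\n";
}
```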

View File

@ -18,22 +18,39 @@ limitations under the License.
// Asynchronous execution: // Asynchronous execution:
// ----------------------- // -----------------------
// //
// If 'asynchronous' is set when constructing the client, computations and // Computations and host-to-device transfers do not need to block the host
// host-to-device transfers do not block the host waiting for the operation to // waiting for the operation to complete but instead return control to the host
// complete but instead return control to the host immediately. This allows // immediately. This allows Python logic to overlap with device-side
// Python logic to overlap with device-side computation. // computation.
// //
// For a good user experience, we must be careful only to enqueue operations // For a good user experience, we must be careful only to enqueue operations
// that are unlikely to fail; as a rule error checking must be done eagerly // that are unlikely to fail; as a rule error checking must be done eagerly
// before returning control to the client. // before returning control to the client.
// //
// The degree to which the client can enqueue operations ahead of the device
// is limited by a semaphore. There are two modes: asynchronous, where we
// allow the client to enqueue up to 32 executions ahead of the device, and
// synchronous, where we limit the client to having one enqueued operation at
// a time. The value of 32 is arbitrary.
//
// Even in asynchronous mode, it is important that we do not permit
// unbounded queue-ahead. Firstly it is problematic when the user does something
// like the following in Python:
// %timeit run_computation()
// To the timeit logic, run_computation() appears to be extremely cheap since
// it is deferring all of its real work and not blocking, and so %timeit will
// run it many
// (e.g., 10000) times to get better timing resolution, even though in reality
// it may be expensive. Secondly, on CPU the allocator is synchronized with the
// head of the compute stream, and we allocate buffers for all of the enqueued
// programs without any reuse (unlike GPU). This means that the memory usage
// is proportional to the queue size.
//
// Multi-stream execution: // Multi-stream execution:
// ----------------------- // -----------------------
// //
// On certain platforms (e.g., TPU), we use a multistream execution design, // We use a multistream execution design, where different Streams are used for
// where different Streams are used for host-to-device transfers, // host-to-device transfers, device-to-host transfers, and compute. This allows
// device-to-host transfers, and compute. This allows us to overlap transfers on // us to overlap transfers on and off the device with computation.
// and off the device with computation.
// //
// Synchronization between streams occurs via BufferDefinitionEvents that // Synchronization between streams occurs via BufferDefinitionEvents that
// describe when the contents of a logical buffer are known to be valid on // describe when the contents of a logical buffer are known to be valid on
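The queue-ahead bound described in the comment above (32 slots in asynchronous mode, 1 in synchronous mode) amounts to a counting semaphore that is acquired before each enqueue and released when the enqueued work completes. The sketch below shows that pattern with a toy condition-variable semaphore; it is illustrative only and is not the `Semaphore` class added under `tensorflow/compiler/xla/python` in this change:

```c++
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// A counting semaphore: Acquire() blocks while `capacity` operations are
// already in flight; Release() frees a slot when one of them completes.
class Semaphore {
 public:
  explicit Semaphore(int capacity) : available_(capacity) {}
  void Acquire() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return available_ > 0; });
    --available_;
  }
  void Release() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      ++available_;
    }
    cv_.notify_one();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int available_;
};

int main() {
  // 32 in asynchronous mode and 1 in synchronous mode per the comment above;
  // 2 here so the bound is easy to observe.
  Semaphore enqueue_budget(2);
  std::mutex print_mu;
  std::vector<std::thread> in_flight;
  for (int i = 0; i < 5; ++i) {
    enqueue_budget.Acquire();  // blocks once two "executions" are outstanding
    in_flight.emplace_back([&enqueue_budget, &print_mu, i] {
      {
        std::lock_guard<std::mutex> lock(print_mu);
        std::cout << "running execution " << i << "\n";
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      enqueue_budget.Release();  // completion frees a slot for the caller
    });
  }
  for (std::thread& t : in_flight) t.join();
}
```

Setting the capacity to 1 collapses this to the synchronous, one-operation-at-a-time mode described above.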
@ -66,9 +83,7 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h" #include "absl/time/time.h"
#include "include/pybind11/pybind11.h" #include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
@ -78,12 +93,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme.h"
@ -93,98 +108,6 @@ namespace xla {
namespace py = pybind11; namespace py = pybind11;
// Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
// "xla._CPU_CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
py::capsule capsule) {
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
if (absl::string_view(capsule.name()) != kName) {
return InvalidArgument(
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
}
CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host");
return Status::OK();
}
Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
bool synchronous_deallocation, bool asynchronous)
: use_multiple_streams_(use_multiple_streams),
synchronous_deallocation_(synchronous_deallocation),
asynchronous_(asynchronous) {
compute_stream_ = std::make_shared<se::Stream>(executor);
compute_stream_->Init();
if (use_multiple_streams) {
host_to_device_stream_ = std::make_shared<se::Stream>(executor);
device_to_host_stream_ = std::make_shared<se::Stream>(executor);
callback_stream_ = std::make_shared<se::Stream>(executor);
host_to_device_stream_->Init();
device_to_host_stream_->Init();
callback_stream_->Init();
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = std::make_shared<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
} else {
callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
compute_stream_;
device_to_device_streams_.push_back(compute_stream_);
}
worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
}
Device::~Device() {
Status status = SynchronizeAllActivity();
if (!status.ok()) {
LOG(ERROR) << "Error when closing device: " << status;
}
}
Status Device::SynchronizeAllActivity() {
Status status;
// TODO(phawkins): in theory the call to SynchronizeAllActivity below should
// suffice. However on the Host platform SynchronizeAllActivity is a dummy
// implementation that doesn't actually block. To make sure activity has
// stopped, also block on the compute stream. If SynchronizeAllActivity is
// fixed, we could remove the BlockHostUntilDone call.
status.Update(compute_stream_->BlockHostUntilDone());
bool ok = compute_stream_->parent()->SynchronizeAllActivity();
if (!ok) {
status.Update(Unknown("SynchronizeAllActivity failed."));
}
return status;
}
Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer) {
// The default implementation simply calls ThenMemcpyD2D, and assumes that
// the buffer addresses identify the devices. This does not work
// on all platforms; this method is virtual so it can be overridden.
src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
return Status::OK();
}
void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
std::function<void()> callback) const {
stream->ThenDoHostCallback(
[this, callback]() { worker_thread_->Schedule(std::move(callback)); });
}
se::Stream* Device::GetDeviceToDeviceStream() {
absl::MutexLock lock(&mu_);
int i = next_device_to_device_stream_;
next_device_to_device_stream_ =
(next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
return device_to_device_streams_.at(i).get();
}
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator( static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
se::Platform* platform, LocalClient* client, double memory_fraction, se::Platform* platform, LocalClient* client, double memory_fraction,
bool preallocate) { bool preallocate) {
@ -237,44 +160,57 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
options.set_platform(platform); options.set_platform(platform);
TF_ASSIGN_OR_RETURN(LocalClient * client, TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options)); ClientLibrary::GetOrCreateLocalClient(options));
bool gpu_platform = platform_name == "gpu";
std::unique_ptr<se::DeviceMemoryAllocator> allocator; std::unique_ptr<se::DeviceMemoryAllocator> allocator;
if (allocator_config.kind == AllocatorConfig::Kind::kBFC || std::unique_ptr<tensorflow::Allocator> host_memory_allocator;
(platform_name == "gpu" && if (gpu_platform) {
allocator_config.kind == AllocatorConfig::Kind::kDefault)) { if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) {
if (platform_name != "gpu") {
return Unimplemented("BFCAllocator only available for GPU.");
}
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto bfc_allocator, allocator,
CreateBFCAllocator(platform, client, allocator_config.memory_fraction, CreateBFCAllocator(platform, client, allocator_config.memory_fraction,
allocator_config.preallocate)); allocator_config.preallocate));
allocator = std::move(bfc_allocator); }
tensorflow::SubAllocator* sub_allocator = new tensorflow::GpuHostAllocator(
client->backend().stream_executor(0).ValueOrDie(), /*numa_node=*/0,
/*alloc_visitors=*/{},
/*free_visitors=*/{});
// TODO(phawkins): allow the user to tune this.
const int64 kGpuHostMemoryLimitBytes = 64 * (1LL << 30);
host_memory_allocator = absl::make_unique<tensorflow::BFCAllocator>(
sub_allocator, kGpuHostMemoryLimitBytes, /*allow_growth=*/true,
/*name=*/"xla_gpu_host_bfc");
} else if (allocator_config.kind == AllocatorConfig::Kind::kBFC) {
return Unimplemented("BFCAllocator only available for GPU.");
} }
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
devices.reserve(client->device_count()); devices.reserve(client->device_count());
bool use_multiple_streams = (platform_name != "cpu"); bool synchronous_deallocation = platform_name == "cpu";
bool synchronous_deallocation = !use_multiple_streams;
for (int i = 0; i < client->device_count(); ++i) { for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutor* executor = se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie(); client->backend().stream_executor(i).ValueOrDie();
devices.push_back(absl::make_unique<Device>(executor, use_multiple_streams, devices.push_back(absl::make_unique<Device>(
synchronous_deallocation, executor, synchronous_deallocation, asynchronous,
asynchronous)); /*allow_event_reuse=*/gpu_platform));
} }
return std::make_shared<PyLocalClient>(platform_name, client, return std::make_shared<PyLocalClient>(
std::move(devices), platform_name, client, std::move(devices), std::move(allocator),
std::move(allocator), asynchronous); std::move(host_memory_allocator));
} }
PyLocalClient::PyLocalClient( PyLocalClient::PyLocalClient(
std::string platform_name, LocalClient* client, std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, std::vector<std::unique_ptr<Device>> devices,
std::unique_ptr<se::DeviceMemoryAllocator> allocator, bool asynchronous) std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator)
: platform_name_(std::move(platform_name)), : platform_name_(std::move(platform_name)),
client_(client), client_(client),
devices_(std::move(devices)), devices_(std::move(devices)),
owned_allocator_(std::move(allocator)), owned_allocator_(std::move(allocator)),
host_memory_allocator_(std::move(host_memory_allocator)),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
client->device_count()) { client->device_count()) {
if (owned_allocator_ != nullptr) { if (owned_allocator_ != nullptr) {
@ -303,56 +239,6 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
return LiteralToPython(std::make_shared<Literal>(std::move(literal))); return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
} }
static StatusOr<std::unique_ptr<PyLocalBuffer>> TransferHostToDeviceAsync(
const PythonBufferTree& tree, int device_ordinal,
std::shared_ptr<PyLocalClient> client, const Device& device) {
se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
transfer_manager->AllocateScopedShapedBuffer(
shape, allocator, device_ordinal));
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
device.host_to_device_stream(), buffer));
auto it = tree.leaves.begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(shape)) {
TF_RET_CHECK(it != tree.leaves.end());
ShapedBuffer leaf(
indexed_shape.shape,
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
client->client()->platform(), device_ordinal);
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
if (device.use_multiple_streams() &&
!transfer_manager->CanShapedBufferBeAccessedNow(
device.host_to_device_stream()->parent(), leaf)) {
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
}
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
device.host_to_device_stream(), *it, leaf));
++it;
}
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (device.use_multiple_streams()) {
TF_ASSIGN_OR_RETURN(definition_event,
BufferDefinitionEvent::Create(
device.host_to_device_stream()->parent()));
definition_event->RecordOnStream(device.host_to_device_stream());
}
std::shared_ptr<SharedDeviceBuffer> device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
definition_event);
if (device.synchronous_deallocation()) {
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
device_buffer);
}
return absl::make_unique<PyLocalBuffer>(shape, std::move(device_buffer),
std::move(client));
}
/* static */ /* static */
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython( StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
const py::object& argument, std::shared_ptr<PyLocalClient> client, const py::object& argument, std::shared_ptr<PyLocalClient> client,
@ -366,82 +252,83 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
// remain live until the transfer is complete. // remain live until the transfer is complete.
auto py_buffer_ref = auto py_buffer_ref =
client->py_ref_manager().ManageReferences(absl::MakeSpan(tree.arrays)); client->py_ref_manager().ManageReferences(absl::MakeSpan(tree.arrays));
tree.arrays.clear();
// We are done manipulating Python objects; release the GIL. // We are done manipulating Python objects; release the GIL.
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString() VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString()
<< " device ordinal: " << device_ordinal; << " device ordinal: " << device_ordinal;
const Device& device = client->device(device_ordinal); Device* device = &client->device(device_ordinal);
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> buffer, TransferManager* transfer_manager =
TransferHostToDeviceAsync(tree, device_ordinal, client->client()->backend().transfer_manager();
std::move(client), device)); se::DeviceMemoryAllocator* allocator = client->allocator();
TF_ASSIGN_OR_RETURN(
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
transfer_manager->AllocateScopedShapedBuffer(
shape, allocator, device_ordinal));
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
device->host_to_device_stream(), buffer));
device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref)); std::vector<std::shared_ptr<void>> staging_buffers;
return buffer; staging_buffers.reserve(tree.leaves.size());
} auto it = tree.leaves.begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
/*static */ StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ShapeUtil::GetLeafShapes(shape)) {
PyLocalBuffer::FromPythonValues( TF_RET_CHECK(it != tree.leaves.end());
const std::vector<std::pair<py::object, int>>& arguments, ShapedBuffer leaf(
std::shared_ptr<PyLocalClient> client) { indexed_shape.shape,
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues"); transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
int num_arguments = static_cast<int>(arguments.size()); client->client()->platform(), device_ordinal);
std::vector<std::unique_ptr<PyLocalBuffer>> outputs(num_arguments); leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
if (num_arguments == 0) { if (!transfer_manager->CanShapedBufferBeAccessedNow(
return outputs; device->host_to_device_stream()->parent(), leaf)) {
device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
} }
struct H2DTransfer { // If applicable on the backend, stage the transfer via host memory
PythonBufferTree tree; // allocated via the host_memory_allocator. On GPU, this is pinned memory.
StatusOr<std::unique_ptr<PyLocalBuffer>> buffer; if (client->host_memory_allocator()) {
PythonRefManager::ManagedPyObjects py_buffer_refs; int64 size = it->size_bytes({});
}; void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
std::vector<H2DTransfer> transfers(num_arguments); std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
for (int i = 0; i < num_arguments; ++i) { client->host_memory_allocator()->DeallocateRaw(ptr);
TF_ASSIGN_OR_RETURN(transfers[i].tree,
GetPythonBufferTree(arguments[i].first));
transfers[i].py_buffer_refs = client->py_ref_manager().ManageReferences(
absl::MakeSpan(transfers[i].tree.arrays));
}
client->py_ref_manager().CollectGarbage();
// We are done manipulating Python objects; release the GIL.
py::gil_scoped_release gil_release;
auto transfer_h2d = [&](int i) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
int device_ordinal = arguments[i].second;
return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
client->device(device_ordinal));
};
// We perform the transfers on a thread pool in case XLA needs to do any
// host-side preprocessing of the input data.
if (num_arguments == 1) {
transfers[0].buffer = transfer_h2d(0);
} else {
absl::BlockingCounter counter(num_arguments);
for (int i = 0; i < num_arguments; ++i) {
client->h2d_transfer_pool()->Schedule([&, i]() {
transfers[i].buffer = transfer_h2d(i);
counter.DecrementCount();
}); });
std::memcpy(ptr, it->untyped_data({}), size);
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
it->shape());
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
device->host_to_device_stream(), literal, leaf));
staging_buffers.push_back(std::move(staging_buffer));
} else {
// Otherwise, just transfer the literal.
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
device->host_to_device_stream(), *it, leaf));
} }
counter.Wait(); ++it;
} }
// Release our references once the transfers have completed. auto definition_event = std::make_shared<BufferDefinitionEvent>();
for (int i = 0; i < num_arguments; ++i) { TF_ASSIGN_OR_RETURN(EventPool::Handle event,
int device_ordinal = arguments[i].second; device->event_pool().ThenAllocateAndRecordEvent(
const Device& device = client->device(device_ordinal); device->host_to_device_stream()));
device.ThenRelease(device.host_to_device_stream(), definition_event->SetDefinitionEvent(std::move(event),
std::move(transfers[i].py_buffer_refs)); device->host_to_device_stream());
}
for (int i = 0; i < num_arguments; ++i) { std::shared_ptr<SharedDeviceBuffer> device_buffer =
TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer)); SharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
definition_event);
if (device->synchronous_deallocation()) {
device->ThenRelease(device->host_to_device_stream(), device_buffer);
} }
return outputs; device->ThenRelease(
device->host_to_device_stream(),
std::make_pair(std::move(py_buffer_ref), std::move(staging_buffers)));
return absl::make_unique<PyLocalBuffer>(shape, std::move(device_buffer),
std::move(client));
} }
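The staging path above rests on one ownership detail: the pinned host buffer is wrapped in a std::shared_ptr whose deleter returns the memory to host_memory_allocator(), and that shared_ptr is then kept alive via Device::ThenRelease on the host-to-device stream, so the staging memory cannot be reused before the asynchronous copy has consumed it. A minimal sketch of that pattern, using only the tensorflow::Allocator calls that appear above (the helper name is illustrative, not part of this change):

#include <memory>

#include "tensorflow/core/framework/allocator.h"

// Allocates `size` bytes from `host` (pinned memory when `host` is the GPU
// host allocator) and ties the allocation's lifetime to the returned pointer.
std::shared_ptr<void> AllocateStagingBuffer(tensorflow::Allocator* host,
                                            size_t size) {
  void* ptr =
      host->AllocateRaw(tensorflow::Allocator::kAllocatorAlignment, size);
  return std::shared_ptr<void>(ptr,
                               [host](void* p) { host->DeallocateRaw(p); });
}

Handing such pointers to the callback enqueued by ThenRelease is what makes it safe for FromPython to return before the transfer completes.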
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple( /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
@ -465,13 +352,9 @@ PyLocalBuffer::FromPythonValues(
se::DeviceMemoryAllocator* allocator = client->allocator(); se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager = TransferManager* transfer_manager =
client->client()->backend().transfer_manager(); client->client()->backend().transfer_manager();
const Device& device = client->device(device_ordinal); Device& device = client->device(device_ordinal);
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (device.use_multiple_streams()) { auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(definition_event,
BufferDefinitionEvent::Create(
device.host_to_device_stream()->parent()));
}
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::shared_ptr<SharedDeviceBuffer> tuple_buffer, std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator, SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,
@ -482,21 +365,22 @@ PyLocalBuffer::FromPythonValues(
// TODO(phawkins): extend TransferManager so we do not need to form a full // TODO(phawkins): extend TransferManager so we do not need to form a full
// ShapedBuffer just to write the root tuple index table. // ShapedBuffer just to write the root tuple index table.
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer()); TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
if (device.use_multiple_streams() && if (!transfer_manager->CanShapedBufferBeAccessedNow(
!transfer_manager->CanShapedBufferBeAccessedNow(
device.host_to_device_stream()->parent(), shaped_buffer)) { device.host_to_device_stream()->parent(), shaped_buffer)) {
// Wait for the compute stream so that memory allocations are synchronized. // Wait for the compute stream so that memory allocations are synchronized.
device.host_to_device_stream()->ThenWaitFor(device.compute_stream()); device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
} }
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
device.host_to_device_stream(), shaped_buffer)); device.host_to_device_stream(), shaped_buffer));
if (definition_event) {
definition_event->RecordOnStream(device.host_to_device_stream()); TF_ASSIGN_OR_RETURN(EventPool::Handle event,
} device.event_pool().ThenAllocateAndRecordEvent(
device.host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.host_to_device_stream());
if (device.synchronous_deallocation()) { if (device.synchronous_deallocation()) {
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(), device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer));
std::move(tuple_buffer));
} }
return buffer; return buffer;
} }
@ -629,8 +513,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
ScopedShapedBuffer dst_buffer, ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer( transfer_manager->AllocateScopedShapedBuffer(
on_host_shape_, client_->allocator(), dst_device_ordinal)); on_host_shape_, client_->allocator(), dst_device_ordinal));
if (dst_device.use_multiple_streams() && if (!transfer_manager->CanShapedBufferBeAccessedNow(
!transfer_manager->CanShapedBufferBeAccessedNow(
dst_device.compute_stream()->parent(), dst_buffer)) { dst_device.compute_stream()->parent(), dst_buffer)) {
src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream()); src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
} }
@ -652,6 +535,10 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
output_buffer)); output_buffer));
} }
// We hold on to the `src_device_buffer` until the transfer is finished.
src_device.ThenRelease(src_device_to_device_stream,
std::move(src_device_buffer));
// Write new tuple buffers. The destination buffers have different addresses, // Write new tuple buffers. The destination buffers have different addresses,
// so we must construct tuple buffers from scratch instead of copying them. // so we must construct tuple buffers from scratch instead of copying them.
if (dst_buffer.on_device_shape().IsTuple()) { if (dst_buffer.on_device_shape().IsTuple()) {
@ -665,13 +552,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
dst_device.host_to_device_stream()); dst_device.host_to_device_stream());
} }
std::shared_ptr<BufferDefinitionEvent> definition_event; auto definition_event = std::make_shared<BufferDefinitionEvent>();
if (dst_device.use_multiple_streams()) { TF_ASSIGN_OR_RETURN(EventPool::Handle event,
TF_ASSIGN_OR_RETURN( src_device.event_pool().ThenAllocateAndRecordEvent(
definition_event, src_device_to_device_stream));
BufferDefinitionEvent::Create(src_device_to_device_stream->parent())); definition_event->SetDefinitionEvent(std::move(event),
definition_event->RecordOnStream(src_device_to_device_stream); src_device_to_device_stream);
}
std::shared_ptr<SharedDeviceBuffer> dst_device_buffer = std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer), SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
@ -756,21 +642,21 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
<< " buffer: " << argument_buffers.back().ToString(); << " buffer: " << argument_buffers.back().ToString();
} }
-  const Device& device = client_->device(device_ordinal);
+  Device* device = &client_->device(device_ordinal);
-  // The choice of where we wait in "synchronous" mode is arbitrary; the reason
-  // for the wait is pacing to avoid problems such as memory fragmentation, not
-  // for correctness.
-  if (!device.asynchronous()) {
-    TF_RETURN_IF_ERROR(device.compute_stream()->BlockHostUntilDone());
-  }
+  // The choice of where we wait is arbitrary; the reason for the wait is pacing
+  // to avoid problems such as memory fragmentation and running ahead too far,
+  // not for correctness. Placing it before the executable launch allows the
+  // inputs for the next executable to be fetched even if the launch is delayed.
+  auto compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
+      device->compute_semaphore().ScopedAcquire(1));
   for (BufferDefinitionEvent* event : events) {
-    event->WaitForEventOnStream(device.compute_stream());
+    event->WaitForEventOnStream(device->compute_stream());
   }
   ExecutableRunOptions options;
-  options.set_stream(device.compute_stream());
-  options.set_host_to_device_stream(device.host_to_device_stream());
+  options.set_stream(device->compute_stream());
+  options.set_host_to_device_stream(device->host_to_device_stream());
   options.set_allocator(client_->allocator());
   options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
@ -787,24 +673,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
return result_buffer.status(); return result_buffer.status();
} }
std::shared_ptr<BufferDefinitionEvent> definition_event; auto definition_event = std::make_shared<BufferDefinitionEvent>();
if (device.use_multiple_streams()) { TF_ASSIGN_OR_RETURN(EventPool::Handle event,
TF_ASSIGN_OR_RETURN( device->event_pool().ThenAllocateAndRecordEvent(
definition_event, device->compute_stream()));
BufferDefinitionEvent::Create(device.compute_stream()->parent())); definition_event->SetDefinitionEvent(std::move(event),
definition_event->RecordOnStream(device.compute_stream()); device->compute_stream());
}
Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape(); Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape();
std::shared_ptr<SharedDeviceBuffer> out_buffer = std::shared_ptr<SharedDeviceBuffer> out_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer( SharedDeviceBuffer::FromScopedShapedBuffer(
std::move(result_buffer.ValueOrDie()), definition_event); std::move(result_buffer.ValueOrDie()), definition_event);
if (device.synchronous_deallocation()) { if (device->synchronous_deallocation()) {
device_buffers.push_back(out_buffer); device_buffers.push_back(out_buffer);
device.ThenReleaseOnWorkerThread(device.compute_stream(), device->ThenRelease(device->compute_stream(), std::move(device_buffers));
std::move(device_buffers));
} }
device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_);
device->ThenRelease(device->compute_stream(),
std::make_pair(executable_, compute_reservation));
return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer), return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer),
client_); client_);
} }
@ -853,7 +740,7 @@ PyLocalExecutable::ExecutePerReplica(
for (int replica = 0; replica < num_replicas(); ++replica) { for (int replica = 0; replica < num_replicas(); ++replica) {
const int device_ordinal = device_assignment_(replica, 0); const int device_ordinal = device_assignment_(replica, 0);
const Device& device = client_->device(device_ordinal); const Device& device = client_->device(device_ordinal);
device.worker_thread()->Schedule([&, replica] { device.execute_thread()->Schedule([&, replica] {
results[replica] = results[replica] =
ExecuteHelper(argument_handles[replica], replica, run_id); ExecuteHelper(argument_handles[replica], replica, run_id);


@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
#include <deque> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -27,137 +27,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/device.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/python/worker_thread.h"
#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla { namespace xla {
// Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
// "xla._CPU_CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
pybind11::capsule capsule);
// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers.
class Device {
public:
// If use_multiple_streams is true, we allocate separate streams for compute
// and transfers. If it is false, we share a single stream for compute and
// transfers. The CPU device does not support multiple streams, and this is
// a workaround until it does.
//
// If synchronous_deallocation is true, the host must not free buffers until
// compute/transfers that use those buffers have completed. For example, this
// typically is the case for the "platform" where compute/transfers are
// operations that take place on another thread.
//
// If asynchronous is false, the host will synchronize to the device after
// each execution or transfer. This is intended for debugging only.
Device(se::StreamExecutor* executor, bool use_multiple_streams,
bool synchronous_deallocation, bool asynchronous);
virtual ~Device();
bool use_multiple_streams() const { return use_multiple_streams_; }
bool synchronous_deallocation() const { return synchronous_deallocation_; }
bool asynchronous() const { return asynchronous_; }
se::Stream* compute_stream() const { return compute_stream_.get(); }
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}
se::Stream* device_to_host_stream() const {
return device_to_host_stream_.get();
}
// Returns a device to device stream. Allocates streams in a round-robin
// fashion amongst the available streams.
se::Stream* GetDeviceToDeviceStream();
// Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer);
// A worker thread, used for replicated computation launches and callbacks.
WorkerThread* worker_thread() const { return worker_thread_.get(); }
// Enqueues a host callback on 'stream', to be executed by worker_thread_.
// ThenDoHostCallback is often constrained in what it can do, in particular,
// on GPU the callback runs on a thread belonging to the GPU runtime and
// cannot perform GPU operations itself.
void ThenExecuteOnWorkerThread(se::Stream* stream,
std::function<void()> callback) const;
// Helper for releasing values from a callback at the tail of a stream.
// This is only permitted if object's destructor will not free any device
// objects, since the callback may be called from a device thread pool on
// GPU.
template <typename T>
void ThenRelease(se::Stream* stream, T object) const {
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
callback_stream_->ThenDoHostCallback(
std::bind([](T& object) { /* releases object */ }, std::move(object)));
}
// Helpers for releasing values on a worker thread at the tail of a stream on
// a worker thread.
template <typename T>
void ThenReleaseOnWorkerThread(se::Stream* stream,
std::shared_ptr<T> object) const {
// We use a non-smart pointer here because we want to ensure that the worker
// thread is the only callee of the shared_ptr destructor, and if we passed
// object by lambda capture we have a race where the worker thread might
// run and release its reference first.
auto* ref = new std::shared_ptr<T>(std::move(object));
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
}
template <typename T>
void ThenReleaseOnWorkerThread(se::Stream* stream,
std::vector<std::shared_ptr<T>> object) const {
auto* ref = new std::vector<std::shared_ptr<T>>(std::move(object));
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
}
private:
Status SynchronizeAllActivity();
bool use_multiple_streams_;
bool synchronous_deallocation_;
bool asynchronous_;
std::shared_ptr<se::Stream> compute_stream_;
std::shared_ptr<se::Stream> host_to_device_stream_;
std::shared_ptr<se::Stream> device_to_host_stream_;
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
// Number of device-to-device streams to create in the multistream case.
static constexpr int kNumDeviceToDeviceStreams = 4;
absl::Mutex mu_;
int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
// Callback stream is used for running short host-side callbacks after device
// side events, without preventing the device-side stream from doing useful
// work.
std::shared_ptr<se::Stream> callback_stream_;
std::unique_ptr<WorkerThread> worker_thread_;
};
struct AllocatorConfig { struct AllocatorConfig {
enum class Kind { enum class Kind {
kDefault, // Client picks the best option for the platform. kDefault, // Client picks the best option for the platform.
@ -188,10 +70,11 @@ class PyLocalClient {
bool asynchronous, const AllocatorConfig& allocator_config); bool asynchronous, const AllocatorConfig& allocator_config);
// `allocator` may be null, in which case the platform default allocator is used.
explicit PyLocalClient(std::string platform_name, LocalClient* client, explicit PyLocalClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, std::vector<std::unique_ptr<Device>> devices,
std::unique_ptr<se::DeviceMemoryAllocator> allocator, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
bool asynchronous); std::unique_ptr<tensorflow::Allocator> host_memory_allocator);
virtual ~PyLocalClient() = default; virtual ~PyLocalClient() = default;
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal); Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
@ -204,6 +87,9 @@ class PyLocalClient {
} }
LocalClient* client() const { return client_; } LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; } se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() { tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_; return &h2d_transfer_pool_;
@ -225,6 +111,11 @@ class PyLocalClient {
se::DeviceMemoryAllocator* allocator_; se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_; std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
// Allocator to be used for staging memory transfers to devices. Optional;
// only used on GPU where it is more efficient to copy buffers to and from the
// device via a staging area of pinned memory.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
tensorflow::thread::ThreadPool h2d_transfer_pool_; tensorflow::thread::ThreadPool h2d_transfer_pool_;
}; };
@ -241,12 +132,6 @@ class PyLocalBuffer {
const pybind11::object& argument, std::shared_ptr<PyLocalClient> client, const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
int device_ordinal); int device_ordinal);
// Converts multiple (python object, device ordinal) pairs into
// PyLocalBuffers in parallel.
static StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> FromPythonValues(
const std::vector<std::pair<pybind11::object, int>>& argument,
std::shared_ptr<PyLocalClient> client);
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple( static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
const std::vector<PyLocalBuffer*> buffers, const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal); std::shared_ptr<PyLocalClient> client, int device_ordinal);


@ -37,9 +37,9 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
} }
} }
-PythonRefManager::ManagedPyObjects PythonRefManager::ManageReferences(
-    absl::Span<py::object> objects) {
-  return ManagedPyObjects(this, objects);
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
+  return std::make_shared<ManagedPyObjects>(this, objects);
 }
void PythonRefManager::CollectGarbage() { void PythonRefManager::CollectGarbage() {


@ -48,9 +48,9 @@ class PythonRefManager {
     ~ManagedPyObjects();
-    ManagedPyObjects(const ManagedPyObjects& other) = default;
+    ManagedPyObjects(const ManagedPyObjects& other) = delete;
     ManagedPyObjects(ManagedPyObjects&& other) = default;
-    ManagedPyObjects& operator=(const ManagedPyObjects& other) = default;
+    ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
     ManagedPyObjects& operator=(ManagedPyObjects&& other) = default;
private: private:
@ -61,7 +61,8 @@ class PythonRefManager {
   // Creates a managed std::shared_ptr to an object. When the shared_ptr is
   // destroyed, the reference to 'object' will be added to python_garbage_,
   // and collected next time CollectGarbage() is called.
-  ManagedPyObjects ManageReferences(absl::Span<pybind11::object> objects);
+  std::shared_ptr<ManagedPyObjects> ManageReferences(
+      absl::Span<pybind11::object> objects);
// Releases the contents of python_garbage_. Requires that the GIL is held. // Releases the contents of python_garbage_. Requires that the GIL is held.
// The client calls this method during API entry points where the GIL is held // The client calls this method during API entry points where the GIL is held
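Returning a std::shared_ptr (and deleting the copy operations) is what lets the transfer path hold Python buffers alive without touching the GIL: the pointer is simply moved into a stream callback. A small usage sketch, assuming a PythonRefManager `manager`, a Device `device` and a PythonBufferTree `tree` as in local_client.cc above; the variable names are illustrative:

std::shared_ptr<PythonRefManager::ManagedPyObjects> refs =
    manager.ManageReferences(absl::MakeSpan(tree.arrays));
tree.arrays.clear();

pybind11::gil_scoped_release gil_release;  // safe: `refs` owns the objects
// ... enqueue host-to-device copies that read the NumPy buffers ...
device.ThenRelease(device.host_to_device_stream(), std::move(refs));
// When the callback fires, ~ManagedPyObjects adds the references to
// python_garbage_, and they are dropped at the next CollectGarbage() call.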


@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
Semaphore::Semaphore(int64 capacity) : value_(capacity) {
CHECK_GE(capacity, 0);
}
bool Semaphore::CanAcquire(CanAcquireArgs* args) {
return args->semaphore->value_ >= args->amount;
}
void Semaphore::Acquire(int64 amount) {
CHECK_GE(amount, 0);
CanAcquireArgs args;
args.semaphore = this;
args.amount = amount;
mu_.LockWhen(absl::Condition(&CanAcquire, &args));
value_ -= amount;
mu_.Unlock();
}
void Semaphore::Release(int64 amount) {
CHECK_GE(amount, 0);
absl::MutexLock lock(&mu_);
value_ += amount;
}
Semaphore::ScopedReservation::~ScopedReservation() {
if (semaphore_) {
semaphore_->Release(amount_);
}
}
Semaphore::ScopedReservation::ScopedReservation(
ScopedReservation&& other) noexcept {
semaphore_ = other.semaphore_;
amount_ = other.amount_;
other.semaphore_ = nullptr;
}
Semaphore::ScopedReservation& Semaphore::ScopedReservation::operator=(
ScopedReservation&& other) noexcept {
semaphore_ = other.semaphore_;
amount_ = other.amount_;
other.semaphore_ = nullptr;
return *this;
}
Semaphore::ScopedReservation Semaphore::ScopedAcquire(int64 amount) {
Acquire(amount);
return ScopedReservation(this, amount);
}
} // namespace xla


@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
class Semaphore {
public:
explicit Semaphore(int64 capacity);
// Acquires `amount` units. Blocks until `amount` units are available.
void Acquire(int64 amount);
// Returns `amount` units to the semaphore.
void Release(int64 amount);
class ScopedReservation {
public:
ScopedReservation(Semaphore* semaphore, int64 amount)
: semaphore_(semaphore), amount_(amount) {}
~ScopedReservation();
ScopedReservation(const ScopedReservation&) = delete;
ScopedReservation(ScopedReservation&& other) noexcept;
ScopedReservation& operator=(const ScopedReservation&) = delete;
ScopedReservation& operator=(ScopedReservation&& other) noexcept;
private:
Semaphore* semaphore_;
int64 amount_;
};
// RAII version of Acquire. Releases the reservation when the
// ScopedReservation is destroyed.
ScopedReservation ScopedAcquire(int64 amount);
private:
struct CanAcquireArgs {
Semaphore* semaphore;
int64 amount;
};
static bool CanAcquire(CanAcquireArgs* args)
EXCLUSIVE_LOCKS_REQUIRED(args->semaphore->mu_);
absl::Mutex mu_;
int64 value_ GUARDED_BY(mu_);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
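This is the primitive behind the compute_semaphore() pacing in PyLocalExecutable::ExecuteHelper above. A short usage sketch; the capacity, the function names and EnqueueWorkAsynchronously() are illustrative placeholders, not part of this change:

#include "tensorflow/compiler/xla/python/semaphore.h"

void EnqueueWorkAsynchronously();  // placeholder for the real work

xla::Semaphore inflight_executions(/*capacity=*/4);

void EnqueueOneExecution() {
  // Blocks while four executions are already enqueued ahead of the device.
  // This is pacing only; ordering still comes from streams and events.
  xla::Semaphore::ScopedReservation slot = inflight_executions.ScopedAcquire(1);
  EnqueueWorkAsynchronously();
  // `slot` is released here, at end of scope. To hold the unit until the work
  // has actually drained, move `slot` into the completion callback instead,
  // as ExecuteHelper does by pairing its reservation with the executable in
  // Device::ThenRelease.
}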


@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
namespace xla {
namespace {
TEST(SemaphoreTest, UnthreadedTests) {
Semaphore semaphore(2);
semaphore.Acquire(1);
semaphore.Release(1);
semaphore.Acquire(2);
semaphore.Release(2);
semaphore.Acquire(1);
semaphore.Acquire(1);
semaphore.Release(1);
semaphore.Acquire(1);
semaphore.Release(1);
semaphore.Acquire(1);
semaphore.Release(2);
{
auto a = semaphore.ScopedAcquire(1);
{ auto b = semaphore.ScopedAcquire(1); }
{ auto c = semaphore.ScopedAcquire(1); }
}
}
TEST(SemaphoreTest, ConcurrentTest) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
Semaphore semaphore(2);
semaphore.Acquire(1);
absl::Notification a_done;
pool.Schedule([&]() {
semaphore.Acquire(2);
semaphore.Release(2);
a_done.Notify();
});
absl::Notification b_done;
pool.Schedule([&]() {
semaphore.Acquire(1);
semaphore.Release(1);
b_done.Notify();
});
b_done.WaitForNotification();
EXPECT_FALSE(a_done.HasBeenNotified());
semaphore.Release(1);
a_done.WaitForNotification();
}
} // namespace
} // namespace xla


@ -21,21 +21,12 @@ limitations under the License.
namespace xla { namespace xla {
-/*static*/ StatusOr<std::shared_ptr<BufferDefinitionEvent>>
-BufferDefinitionEvent::Create(se::StreamExecutor* executor) {
-  auto event = std::make_shared<BufferDefinitionEvent>(executor);
-  TF_RET_CHECK(event->event_.Init())
-      << "Buffer definition event initialization failed";
-  return event;
-}
-BufferDefinitionEvent::BufferDefinitionEvent(se::StreamExecutor* executor)
-    : event_(executor) {}
-void BufferDefinitionEvent::RecordOnStream(se::Stream* stream) {
+void BufferDefinitionEvent::SetDefinitionEvent(EventPool::Handle event,
+                                               se::Stream* stream) {
   absl::MutexLock lock(&mu_);
+  CHECK(!event_.event());
+  event_ = std::move(event);
   CHECK(streams_defined_on_.empty());
-  stream->ThenRecordEvent(&event_);
   streams_defined_on_.push_back(stream);
 }
@ -50,7 +41,7 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) {
return; return;
} }
stream->ThenWaitFor(&event_); stream->ThenWaitFor(event_.event());
streams_defined_on_.push_back(stream); streams_defined_on_.push_back(stream);
} }


@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape.h"
@ -50,14 +51,11 @@ namespace xla {
// same stream causes no additional waiting. // same stream causes no additional waiting.
class BufferDefinitionEvent { class BufferDefinitionEvent {
public: public:
// Creates a new definition event whose event has not yet been triggered. BufferDefinitionEvent() = default;
static StatusOr<std::shared_ptr<BufferDefinitionEvent>> Create(
se::StreamExecutor* executor);
explicit BufferDefinitionEvent(se::StreamExecutor* executor); // Sets the definition event of the buffer to 'event', which is recorded
// on 'stream'. Must be called at most once.
// Records the definition event on the tail of 'stream'. void SetDefinitionEvent(EventPool::Handle event, se::Stream* stream);
void RecordOnStream(se::Stream* stream);
// Adds synchronization events to 'stream' that wait for this event to be // Adds synchronization events to 'stream' that wait for this event to be
// defined on 'stream'. Does nothing if the event is already known to have // defined on 'stream'. Does nothing if the event is already known to have
@ -68,7 +66,7 @@ class BufferDefinitionEvent {
// An event that is triggered when the content of one or more buffers is // An event that is triggered when the content of one or more buffers is
// ready. If this event is nullptr, it is assumed that the buffer's content is // ready. If this event is nullptr, it is assumed that the buffer's content is
// always defined. // always defined.
se::Event event_; EventPool::Handle event_;
absl::Mutex mu_; absl::Mutex mu_;
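Taken together, the intended lifecycle of the reworked event is: allocate-and-record an event from the device's EventPool on the stream that produces the buffer, attach it with SetDefinitionEvent, and have every consumer stream wait on it before reading. A sketch built only from calls that appear in local_client.cc above (the `device` accessors belong to the Device class now defined in device.h):

auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
                    device->event_pool().ThenAllocateAndRecordEvent(
                        device->host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
                                     device->host_to_device_stream());

// Later, before any consumer touches the buffer on another stream:
definition_event->WaitForEventOnStream(device->compute_stream());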


@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <cstdint>
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/base/casts.h"
#include "absl/hash/hash.h" #include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
@ -34,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/xrt.h" #include "tensorflow/compiler/xla/python/xrt.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
@ -104,6 +108,22 @@ StatusOr<std::string> GetComputationHloDotGraph(
RenderedGraphFormat::kDot); RenderedGraphFormat::kDot);
} }
// Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
// "xla._CPU_CUSTOM_CALL_TARGET".
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
py::capsule capsule) {
static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
if (absl::string_view(capsule.name()) != kName) {
return InvalidArgument(
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
}
CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host");
return Status::OK();
}
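A caller therefore has to supply a PyCapsule whose name matches kName and whose payload is the function pointer XLA should invoke. A hedged pybind11 sketch; MyCpuKernel and its signature are illustrative assumptions rather than part of this change:

extern "C" void MyCpuKernel(void* out, const void** in);  // assumed kernel

pybind11::capsule MakeCpuCustomCallCapsule() {
  // The capsule name must be exactly "xla._CPU_CUSTOM_CALL_TARGET", or
  // RegisterCpuCustomCallTarget above rejects it.
  return pybind11::capsule(reinterpret_cast<void*>(&MyCpuKernel),
                           "xla._CPU_CUSTOM_CALL_TARGET");
}

Registering it with RegisterCpuCustomCallTarget("my_kernel", MakeCpuCustomCallCapsule()) then exposes the function to CustomCallTargetRegistry under the "Host" platform.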
} // namespace } // namespace
PYBIND11_MODULE(xla_extension, m) { PYBIND11_MODULE(xla_extension, m) {
@ -297,7 +317,6 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyLocalBuffer>(m, "PyLocalBuffer") py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
.def_static("from_python", &PyLocalBuffer::FromPython) .def_static("from_python", &PyLocalBuffer::FromPython)
.def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
.def_static("make_tuple", &PyLocalBuffer::MakeTuple) .def_static("make_tuple", &PyLocalBuffer::MakeTuple)
.def("copy_to_device", &PyLocalBuffer::CopyToDevice) .def("copy_to_device", &PyLocalBuffer::CopyToDevice)
.def("delete", &PyLocalBuffer::Delete) .def("delete", &PyLocalBuffer::Delete)
@ -307,8 +326,21 @@ PYBIND11_MODULE(xla_extension, m) {
.def("to_py", &PyLocalBuffer::ToPython) .def("to_py", &PyLocalBuffer::ToPython)
.def("shape", &PyLocalBuffer::on_host_shape) .def("shape", &PyLocalBuffer::on_host_shape)
.def("device", &PyLocalBuffer::device_ordinal) .def("device", &PyLocalBuffer::device_ordinal)
.def("is_deleted", [](const PyLocalBuffer& buffer) { .def("is_deleted",
[](const PyLocalBuffer& buffer) {
return buffer.DeviceBuffer() == nullptr; return buffer.DeviceBuffer() == nullptr;
})
.def("unsafe_buffer_pointer",
[](const PyLocalBuffer& buffer) -> StatusOr<std::uintptr_t> {
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer,
buffer.AsShapedBuffer());
if (shaped_buffer.on_device_shape().IsTuple()) {
return Unimplemented(
"unsafe_buffer_pointer is not implemented for tuple "
"buffers.");
}
return absl::bit_cast<std::uintptr_t>(
shaped_buffer.root_buffer().opaque());
}); });
py::class_<PyLocalExecutable>(m, "LocalExecutable") py::class_<PyLocalExecutable>(m, "LocalExecutable")

View File

@ -98,10 +98,6 @@ class LocalBackend(Backend):
def buffer_from_pyval(self, pyval, device=0): def buffer_from_pyval(self, pyval, device=0):
return _xla.PyLocalBuffer.from_python(pyval, self.client, device) return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
def buffers_from_pyvals(self, pyvals_and_devices):
return _xla.PyLocalBuffer.from_python_values(pyvals_and_devices,
self.client)
def make_tuple(self, c_buffers, device_ordinal): def make_tuple(self, c_buffers, device_ordinal):
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal) return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal)


@ -644,6 +644,7 @@ cc_library(
hdrs = ["call_inliner.h"], hdrs = ["call_inliner.h"],
deps = [ deps = [
":call_graph", ":call_graph",
":hlo",
":hlo_dce", ":hlo_dce",
":hlo_pass", ":hlo_pass",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
@ -1095,50 +1096,6 @@ tf_cc_test(
], ],
) )
cc_library(
name = "buffer_liveness",
srcs = [
"buffer_liveness.cc",
],
hdrs = [
"buffer_liveness.h",
],
deps = [
":hlo",
":hlo_ordering",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
)
tf_cc_test(
name = "buffer_liveness_test",
srcs = ["buffer_liveness_test.cc"],
deps = [
":buffer_liveness",
":hlo",
":hlo_dataflow_analysis",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/memory",
],
)
cc_library( cc_library(
name = "buffer_assignment", name = "buffer_assignment",
srcs = [ srcs = [
@ -1148,7 +1105,6 @@ cc_library(
"buffer_assignment.h", "buffer_assignment.h",
], ],
deps = [ deps = [
":buffer_liveness",
":buffer_value_containers", ":buffer_value_containers",
":heap_simulator", ":heap_simulator",
":hlo", ":hlo",
@ -2785,7 +2741,6 @@ cc_library(
srcs = ["copy_insertion.cc"], srcs = ["copy_insertion.cc"],
hdrs = ["copy_insertion.h"], hdrs = ["copy_insertion.h"],
deps = [ deps = [
":buffer_liveness",
":dump", ":dump",
":hlo", ":hlo",
":hlo_alias_analysis", ":hlo_alias_analysis",
@ -2907,7 +2862,6 @@ cc_library(
srcs = ["hlo_rematerialization.cc"], srcs = ["hlo_rematerialization.cc"],
hdrs = ["hlo_rematerialization.h"], hdrs = ["hlo_rematerialization.h"],
deps = [ deps = [
":buffer_liveness",
":buffer_value", ":buffer_value",
":call_graph", ":call_graph",
":flatten_call_graph", ":flatten_call_graph",
@ -3626,7 +3580,6 @@ cc_library(
srcs = ["reduce_precision_insertion.cc"], srcs = ["reduce_precision_insertion.cc"],
hdrs = ["reduce_precision_insertion.h"], hdrs = ["reduce_precision_insertion.h"],
deps = [ deps = [
":buffer_liveness",
":hlo", ":hlo",
":hlo_pass", ":hlo_pass",
":hlo_pass_pipeline", ":hlo_pass_pipeline",
@ -3946,7 +3899,6 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",


@ -168,13 +168,8 @@ bool IsUnstridedSlice(const HloInstruction* hlo) {
// algebraic expressions to simplified forms. Note: This only supports // algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the // simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed. // more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
public: public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleAdd(HloInstruction* add) override; Status HandleAdd(HloInstruction* add) override;
Status HandleAnd(HloInstruction* logical_and) override; Status HandleAnd(HloInstruction* logical_and) override;
@ -250,9 +245,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleMap(HloInstruction* map) override; Status HandleMap(HloInstruction* map) override;
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
// Runs the visitor on a computation. // Runs the visitor on a computation.
static bool Run(HloComputation* computation, static bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options, const AlgebraicSimplifierOptions& options,
@ -350,35 +342,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* broadcast); HloInstruction* broadcast);
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithNewInstruction(
HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
old_instruction, std::move(new_instruction)));
changed_ = true;
return Status::OK();
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction) {
VLOG(3) << "Replacing instruction:";
VLOG(3) << " old: " << old_instruction->ToString();
VLOG(3) << " new: " << new_instruction->ToString();
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(old_instruction, new_instruction));
changed_ = true;
return Status::OK();
}
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
@ -445,7 +408,7 @@ bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
AlgebraicSimplifier* simplifier) { AlgebraicSimplifier* simplifier) {
AlgebraicSimplifierVisitor visitor(computation, options, simplifier); AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
TF_CHECK_OK(computation->Accept(&visitor)); TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_; return visitor.changed_ || visitor.changed();
} }
bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
@ -1723,6 +1686,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
} }
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
CHECK(computation_ == dot->parent());
HloInstruction *lhs, *rhs; HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
if (options_.is_layout_sensitive()) { if (options_.is_layout_sensitive()) {
@ -2660,6 +2624,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
HloInstruction *a, *b; HloInstruction *a, *b;
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
// (A % B) % B == A % B.
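  // (This holds because A % B already has magnitude strictly less than |B|
  // and takes the sign of A, so reducing it by B a second time is a no-op.)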
if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
return ReplaceInstruction(remainder, a);
}
// A % B => A & (B - 1) if B is a power of 2. // A % B => A & (B - 1) if B is a power of 2.
switch (remainder->shape().element_type()) { switch (remainder->shape().element_type()) {
case S8: case S8:


@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
} }
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
const char* kModuleStr = R"(
HloModule m
test {
p = s32[1000] parameter(0)
q = s32[1000] parameter(1)
r = s32[1000] remainder(p, q)
ROOT rr = s32[1000] remainder(r, q)
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
} // namespace } // namespace
} // namespace xla } // namespace xla


@ -46,13 +46,8 @@ using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations. // operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
public: public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override;
@ -63,9 +58,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
static bool Run(HloComputation* computation, bool rewrite_training_op, static bool Run(HloComputation* computation, bool rewrite_training_op,
bool rewrite_inference_op, bool rewrite_grad_op); bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
~BatchNormExpanderVisitor() override = default; ~BatchNormExpanderVisitor() override = default;
private: private:
@ -133,28 +125,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
elements_per_feature_u32); elements_per_feature_u32);
} }
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithNewInstruction(
HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) {
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
old_instruction, std::move(new_instruction)));
changed_ = true;
return Status::OK();
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceInstruction(HloInstruction* old_instruction,
HloInstruction* new_instruction) {
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(old_instruction, new_instruction));
changed_ = true;
return Status::OK();
}
// Current HloComputation instance the BatchNormExpander is // Current HloComputation instance the BatchNormExpander is
// traversing. // traversing.
HloComputation* computation_; HloComputation* computation_;
@ -162,9 +132,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool rewrite_training_op_; bool rewrite_training_op_;
bool rewrite_inference_op_; bool rewrite_inference_op_;
bool rewrite_grad_op_; bool rewrite_grad_op_;
// Whether rewrite has occurred.
bool changed_ = false;
}; };
} // namespace } // namespace
@ -179,7 +146,7 @@ bool BatchNormExpanderVisitor::Run(HloComputation* computation,
/*rewrite_inference_op=*/rewrite_inference_op, /*rewrite_inference_op=*/rewrite_inference_op,
/*rewrite_grad_op=*/rewrite_grad_op); /*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor)); TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_; return visitor.changed();
} }
Status BatchNormExpanderVisitor::HandleBatchNormTraining( Status BatchNormExpanderVisitor::HandleBatchNormTraining(


@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "absl/types/span.h"
+ #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -656,7 +656,7 @@ Status BufferAssignment::ComputeSummaryStats() {
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
-         liveness_->hlo_ordering().SequentialOrder(*computation);
+         hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
@ -833,7 +833,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
  const HloValue& assigned_buffer =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
  for (const HloValue* new_value : hlo_buffer.values()) {
-   if (assignment->liveness().hlo_ordering().MayInterfere(
+   if (assignment->hlo_ordering().MayInterfere(
            assigned_buffer, *new_value, assignment->dataflow_analysis())) {
      VLOG(4) << "Can't assign: assignee " << assigned_buffer
              << " may interfere with " << new_value;
@ -917,7 +917,7 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
  for (const HloValue* instruction_value : instruction_buffer.values()) {
    for (const HloValue* operand_value : operand_buffer.values()) {
-     if (assignment->liveness().hlo_ordering().MayInterfere(
+     if (assignment->hlo_ordering().MayInterfere(
              *instruction_value, *operand_value,
              assignment->dataflow_analysis())) {
        interfere = true;
@ -1047,8 +1047,7 @@ Status BufferAssigner::AssignSingleHloBuffer(
  for (const HloValue* hlo_value : hlo_buffer->values()) {
    HloComputation* computation = hlo_value->instruction()->parent();
    const bool has_sequential_order =
-       assignment->liveness().hlo_ordering().SequentialOrder(*computation) !=
-       nullptr;
+       assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
    all_computations_have_sequential_order &= has_sequential_order;
  }
@ -1125,10 +1124,9 @@ Status BufferAssigner::AssignBuffersForComputations(
    }
  }
- const BufferLiveness& liveness = assignment->liveness();
  for (const HloComputation* computation : computations) {
    const bool has_sequential_order =
-       liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
+       assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
@ -1197,7 +1195,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
  // Run the sequence of instructions through the heap simulator. The
  // heuristic that seems to give the best results is lazy-best-fit, with all
  // runs of alloc / free calls sorted in decreasing size order.
- const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+ const HloOrdering& hlo_ordering = assignment->hlo_ordering();
  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
@ -1392,9 +1390,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
-                     BufferLiveness::Run(module, std::move(hlo_ordering)));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));
@ -1408,11 +1403,11 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
  // Can't use absl::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
-     module, std::move(liveness), std::move(buffer_size),
+     module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis)));
- TF_RETURN_IF_ERROR(colorer_(&assignment->alias_analysis(),
-                             assignment->liveness().hlo_ordering()));
+ TF_RETURN_IF_ERROR(
+     colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());
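One comment kept by the hunk above describes the heap-simulator heuristic: lazy best-fit, with runs of alloc/free calls sorted in decreasing size order. As a very loose, self-contained illustration of the "largest buffers first" half of that idea only (a first-fit toy over explicit live ranges, not the XLA heap simulator and not lazy best-fit), offsets can be assigned greedily like so:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical buffer that is live over [start, end) in program order.
struct ToyBuffer {
  int64_t size;
  int start;
  int end;
  int64_t offset = -1;  // Assigned below.
};

bool LiveRangesOverlap(const ToyBuffer& a, const ToyBuffer& b) {
  return a.start < b.end && b.start < a.end;
}

// Largest-first greedy packing: each buffer takes the lowest offset that does
// not collide with any already-placed buffer whose live range overlaps.
// Returns the resulting heap size.
int64_t AssignOffsets(std::vector<ToyBuffer>& buffers) {
  std::vector<ToyBuffer*> order;
  for (ToyBuffer& b : buffers) order.push_back(&b);
  std::sort(order.begin(), order.end(),
            [](const ToyBuffer* a, const ToyBuffer* b) {
              return a->size > b->size;
            });
  int64_t heap_size = 0;
  for (ToyBuffer* b : order) {
    int64_t offset = 0;
    bool placed = false;
    while (!placed) {
      placed = true;
      for (const ToyBuffer& other : buffers) {
        if (&other == b || other.offset < 0) continue;
        if (!LiveRangesOverlap(*b, other)) continue;
        // Collision in both time and address space: bump past it and rescan.
        if (offset < other.offset + other.size &&
            other.offset < offset + b->size) {
          offset = other.offset + other.size;
          placed = false;
        }
      }
    }
    b->offset = offset;
    heap_size = std::max(heap_size, offset + b->size);
  }
  return heap_size;
}
```

The real code additionally tries several heap algorithms and keeps the best result, which is what the retained comment "Returns a heap algorithm that chooses the best result from several algorithms" refers to.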
@ -25,7 +25,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
- #include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@ -440,11 +439,6 @@ class BufferAssignment {
  bool HaveDisjointSlices(const HloInstruction* hlo_a,
                          const HloInstruction* hlo_b) const;
-  // Returns the underlying points-to analysis used for this assignment.
-  const TuplePointsToAnalysis& points_to_analysis() const {
-    return liveness_->points_to_analysis();
-  }
  const HloDataflowAnalysis& dataflow_analysis() const {
    return alias_analysis_->dataflow_analysis();
  }
@ -452,7 +446,7 @@ class BufferAssignment {
  HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
  // Returns the BufferLiveness object used to construct this assignment.
-  const BufferLiveness& liveness() const { return *liveness_; }
+  const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
  string ToString() const;
  BufferAssignmentProto ToProto() const;
@ -483,12 +477,12 @@ class BufferAssignment {
  friend class BufferAssigner;
  BufferAssignment(const HloModule* module,
-                  std::unique_ptr<BufferLiveness> liveness,
+                  std::unique_ptr<HloOrdering> hlo_ordering,
                   BufferValue::SizeFunction buffer_size,
                   LogicalBuffer::AlignmentFunction color_alignment,
                   std::unique_ptr<HloAliasAnalysis> alias_analysis)
      : module_(module),
-       liveness_(std::move(liveness)),
+       hlo_ordering_(std::move(hlo_ordering)),
        buffer_size_(std::move(buffer_size)),
        color_alignment_(std::move(color_alignment)),
        alias_analysis_(std::move(alias_analysis)) {}
@ -540,7 +534,8 @@ class BufferAssignment {
      allocation_index_for_value_;
  const HloModule* module_;
-  const std::unique_ptr<BufferLiveness> liveness_;
+  const std::unique_ptr<HloOrdering> hlo_ordering_;
  // Function which returns the buffer size for a given logical buffer (shape).
  BufferValue::SizeFunction buffer_size_;
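Net effect of the buffer_assignment hunks: BufferAssignment now owns the HloOrdering directly instead of going through a BufferLiveness wrapper, and the points_to_analysis() accessor disappears in favor of the alias/dataflow analyses. For call sites the change is mostly one less hop; the fragment below is assembled from the hunks above (it is not a compilable unit on its own, and introduces no API beyond what the diff shows):

```cpp
// Before: the ordering was reached through the BufferLiveness wrapper.
const HloInstructionSequence* sequence_before =
    assignment->liveness().hlo_ordering().SequentialOrder(*computation);

// After: BufferAssignment exposes the ordering itself.
const HloInstructionSequence* sequence_after =
    assignment->hlo_ordering().SequentialOrder(*computation);
```

With the ordering exposed directly and interference answered by HloOrdering::MayInterfere over the dataflow analysis, the standalone BufferLiveness class becomes redundant, which is why the next three files in the diff are deleted outright.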
@ -1,179 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Defines the data returned by the XLA buffer assignment packages.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include <utility>
#include <vector>
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
/* static */
StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering) {
std::unique_ptr<BufferLiveness> liveness(
new BufferLiveness(module, std::move(hlo_ordering)));
TF_RETURN_IF_ERROR(liveness->Analyze());
return std::move(liveness);
}
Status BufferLiveness::Analyze() {
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
for (auto* computation : module_->computations()) {
if (computation->IsFusionComputation()) {
continue;
}
// Gather all instructions whose buffers might alias other instructions into
// the set aliased_buffers_. This includes those contained as a tuple
// element in other instruction's output.
for (const auto& instruction : computation->instructions()) {
for (const LogicalBuffer* aliased_buffer :
points_to_analysis_->GetPointsToSet(instruction)
.CreateFlattenedSet()) {
if (aliased_buffer->instruction() != instruction) {
aliased_buffers_.insert(aliased_buffer);
}
}
}
if (computation == module_->entry_computation()) {
const HloInstruction* root = computation->root_instruction();
maybe_live_out_buffers_ =
points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet();
}
}
XLA_VLOG_LINES(3, ToString());
return Status::OK();
}
string BufferLiveness::ToString() const {
std::vector<string> pieces;
pieces.push_back(
absl::StrFormat("BufferLiveness(module=%s):", module_->name()));
pieces.push_back("HloOrdering:");
pieces.push_back(hlo_ordering_->ToString());
pieces.push_back("Aliased buffers:");
for (const LogicalBuffer* buffer : aliased_buffers_) {
pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
pieces.push_back("Live out buffers:");
for (const LogicalBuffer* buffer : maybe_live_out_buffers_) {
pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
return absl::StrJoin(pieces, "\n");
}
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a));
TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b));
if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
return false;
}
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
for (auto user : alias.instruction()->users()) {
if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(),
alias.index(), user)) {
continue;
}
if (user != b.instruction() &&
!hlo_ordering_->ExecutesBefore(user, b.instruction())) {
return false;
}
}
// If the root instruction aliases the buffer 'a', the live range of 'a' is
// until the end of the computation and can never be strictly before another
// buffer nested in the same computation. This is needed to prevent the root
// instruction's buffers from being reused by later instructions even when
// the root is not the last instruction in the schedule.
if (alias.instruction()->parent()->root_instruction() ==
alias.instruction() &&
hlo_ordering_->call_graph().InstructionIsNestedIn(
b.instruction(), alias.instruction()->parent())) {
return false;
}
}
// If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction'
// and 'b.instruction' emit the same shape/layout, and 'b.instruction' meets
// the qualifications specified in CanShareOperandBufferWithUser.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
if (b.instruction()->IsUserOf(alias.instruction()) &&
!points_to_analysis().CanShareOperandBufferWithUser(
alias.instruction(), alias.index(), b.instruction(), b.index())) {
return false;
}
}
return true;
}
namespace {
bool IsEntryParameter(const HloInstruction* instruction) {
const HloComputation* computation = instruction->parent();
return instruction->opcode() == HloOpcode::kParameter &&
computation == computation->parent()->entry_computation();
}
} // namespace
bool BufferLiveness::MayInterfere(const LogicalBuffer& a,
const LogicalBuffer& b) const {
// Parameters live at the entry of the computation, thus always interfere with
// all other instructions inside the computation executing before them in the
// ordering.
const HloInstruction* a_instruction = a.instruction();
const HloInstruction* b_instruction = b.instruction();
if (a_instruction->opcode() == HloOpcode::kParameter &&
hlo_ordering_->call_graph().InstructionIsNestedIn(
b_instruction, a_instruction->parent()) &&
hlo_ordering_->ExecutesBefore(b_instruction, a_instruction)) {
return true;
}
if (b_instruction->opcode() == HloOpcode::kParameter &&
hlo_ordering_->call_graph().InstructionIsNestedIn(
a_instruction, b_instruction->parent()) &&
hlo_ordering_->ExecutesBefore(a_instruction, b_instruction)) {
return true;
}
// Buffers without disjoint liveness may interfere.
return !live_range_strictly_before(a, b) && !live_range_strictly_before(b, a);
}
bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const {
// Verify that a buffer is actually defined at the given instruction/index
  // (e.g., it's not an alias of another buffer, such as occurs with a bitcast).
TF_CHECK_OK(points_to_analysis_->VerifyBuffer(buffer));
return maybe_live_out_buffers_.count(&buffer);
}
} // namespace xla
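The interference test deleted above boils down to a symmetric rule, visible in the last lines of MayInterfere: two buffers may interfere unless one live range is strictly before the other. A minimal standalone version over plain intervals (toy types; it deliberately ignores the entry-parameter and CanShareOperandBufferWithUser special cases handled in the deleted code):

```cpp
#include <cassert>

// Toy live range over a linear instruction order: defined at 'def',
// last read at 'last_use'.
struct LiveRange {
  int def;
  int last_use;
};

// 'a' is strictly before 'b' if it is completely dead before 'b' is defined.
bool LiveRangeStrictlyBefore(const LiveRange& a, const LiveRange& b) {
  return a.last_use < b.def;
}

// Mirrors the deleted rule: buffers without disjoint liveness may interfere.
bool MayInterfere(const LiveRange& a, const LiveRange& b) {
  return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
}

int main() {
  assert(!MayInterfere({0, 1}, {2, 3}));  // Disjoint ranges: can share a buffer.
  assert(MayInterfere({0, 3}, {2, 5}));   // Overlapping ranges: may interfere.
  assert(MayInterfere({1, 2}, {1, 2}));   // A range always interferes with itself.
  return 0;
}
```

The deleted XLA version refines "strictly before" with alias, user, and root-instruction checks, but the overall shape of the decision is exactly this two-sided test.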
@ -1,114 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
#include <memory>
#include <string>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
// Class which computes liveness of the output buffers of HLOs and their
// interference.
class BufferLiveness {
public:
using Colorer = std::function<Status(const BufferLiveness& buffer_liveness)>;
// Constructs a buffer liveness object for the given module assuming the given
// HLO instruction ordering.
static StatusOr<std::unique_ptr<BufferLiveness>> Run(
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering);
// Returns true if the live range of the buffer containing the output of 'a'
// may overlap with the live range of the buffer of 'b'. If instruction 'a'
// interferes with instruction 'b' then they cannot share the same buffer.
bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;
// Returns true if the buffer for the given instruction may be live out of the
// module. That is, the instruction's buffer may be included in the output of
// the entry computation.
bool MaybeLiveOut(const LogicalBuffer& buffer) const;
// Returns the complete set of buffers that may be live out of the module.
const PointsToSet::BufferSet& maybe_live_out_buffers() const {
return maybe_live_out_buffers_;
}
// Returns the underlying points-to analysis used for this liveness analysis.
const TuplePointsToAnalysis& points_to_analysis() const {
return *points_to_analysis_;
}
// Returns the underlying hlo ordering used for this liveness analysis.
const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
const HloModule& module() const { return *module_; }
string ToString() const;
static Colorer DefaultColorer() {
return [](const BufferLiveness& buffer_liveness) {
for (LogicalBuffer::Id id = 0;
id < buffer_liveness.points_to_analysis().num_logical_buffers();
id++) {
auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
buffer.set_color(LogicalBuffer::Color(0));
}
return Status::OK();
};
}
private:
explicit BufferLiveness(const HloModule* module,
std::unique_ptr<HloOrdering> hlo_ordering)
: module_(module), hlo_ordering_(std::move(hlo_ordering)) {}
// Perform buffer liveness analysis. This method must be called prior to
// MayInterfere or MaybeLiveOut.
Status Analyze();
// Returns true if the live range of the buffer of 'a' is strictly before the
// live range of the buffer of 'b' (they do not overlap).
bool live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const;
const HloModule* module_;
std::unique_ptr<HloOrdering> hlo_ordering_;
// Set of LogicalBuffers which are aliased in the output of other
// instructions. For example, a LogicalBuffer which is inserted into a tuple
// is considered to be aliased and will be in this set.
absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
// LogicalBuffers that may be live out of the entry computation.
PointsToSet::BufferSet maybe_live_out_buffers_;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
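A small detail from the deleted header: DefaultColorer() is just a std::function that walks every logical buffer and assigns color 0, i.e. a single memory class; buffer coloring is a pluggable hook rather than a fixed policy. A stripped-down sketch of that shape (toy buffer type, not the XLA LogicalBuffer):

```cpp
#include <functional>
#include <vector>

// Toy logical buffer; 'color' partitions buffers into allocation classes.
struct ToyLogicalBuffer {
  int id = 0;
  int color = -1;
};

using Colorer = std::function<void(std::vector<ToyLogicalBuffer>&)>;

// Analogue of the deleted BufferLiveness::DefaultColorer(): everything gets
// color 0, so all buffers may be packed into the same heap.
Colorer DefaultColorer() {
  return [](std::vector<ToyLogicalBuffer>& buffers) {
    for (ToyLogicalBuffer& buffer : buffers) buffer.color = 0;
  };
}

int main() {
  std::vector<ToyLogicalBuffer> buffers = {{0}, {1}, {2}};
  DefaultColorer()(buffers);
  return 0;
}
```

In the new code path the colorer receives the alias analysis and the HloOrdering instead, as seen in the CreateAssignment hunk above (colorer_(&assignment->alias_analysis(), assignment->hlo_ordering())).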
@ -1,934 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class BufferLivenessTest : public HloTestBase {
protected:
// Returns the LogicalBuffer defined at the given instruction and
// index. CHECKs if no buffer is defined at that point.
const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
const HloInstruction* instruction,
const ShapeIndex& index) {
const auto& pointed_to = liveness.points_to_analysis()
.GetPointsToSet(instruction)
.element(index);
CHECK_EQ(1, pointed_to.size());
CHECK_EQ(instruction, pointed_to[0]->instruction());
CHECK(index == pointed_to[0]->index());
return *pointed_to[0];
}
// Returns true if the top-level buffers for instructions 'a' and 'b' may
// interfere. Precondition: 'a' and 'b' are array-shaped.
bool InstructionsMayInterfere(const BufferLiveness& liveness,
HloInstruction* a, HloInstruction* b) {
EXPECT_FALSE(a->shape().IsTuple());
EXPECT_FALSE(b->shape().IsTuple());
return liveness.MayInterfere(
GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
}
// Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
// may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
// tuple element sub-shapes.
bool TupleElementsMayInterfere(const BufferLiveness& liveness,
HloInstruction* a, HloInstruction* b,
const ShapeIndex& index) {
// Check that top-level shapes are tuple and tuple element shapes are equal.
EXPECT_TRUE(a->shape().IsTuple());
EXPECT_TRUE(b->shape().IsTuple());
EXPECT_TRUE(
ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
ShapeUtil::GetSubshape(b->shape(), index)));
// Lookup PointsTo set for instructions 'a' and 'b'.
auto& points_to_analysis = liveness.points_to_analysis();
const auto& points_to_a =
points_to_analysis.GetPointsToSet(a).element(index);
const auto& points_to_b =
points_to_analysis.GetPointsToSet(b).element(index);
// Make sure PointsTo sets for 'a' and 'b' are unambiguous.
EXPECT_EQ(1, points_to_a.size());
EXPECT_EQ(points_to_a.size(), points_to_b.size());
// Check interference.
return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
}
  // Returns true if the top-level buffer for the given instruction may be
  // live out of the entry computation.
// Precondition: instruction is array-shaped.
bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
HloInstruction* instruction) {
return liveness.MaybeLiveOut(
GetBuffer(liveness, instruction, /*index=*/{}));
}
std::unique_ptr<HloComputation> BuildDummyComputation() {
auto builder = HloComputation::Builder(TestName() + "_dummy");
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
return builder.Build();
}
const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
};
TEST_F(BufferLivenessTest, ElementwiseChain) {
// A simple chain of elementwise operations. No buffers should interfere.
//
// param --> negate -> exp -> log
//
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
// No buffers should interfere.
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp));
  // A buffer should interfere with itself.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp));
// Only log is live out.
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate));
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log));
}
TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
// Two entry params, which interfere with each other.
//
// param0 --> negate ---------------\
// param1 --> exp --> add
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, vec_, "param0"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, vec_, "param1"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
auto module = CreateNewVerifiedModule();
HloComputation* entry = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(entry, {param0, negate, param1, exp, add});
auto liveness =
BufferLiveness::Run(module.get(),
absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add));
// Negate and exp still interfere.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
// But {negate, add} and {exp, add} don't interfere.
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
}
TEST_F(BufferLivenessTest, EmbeddedComputationParameters) {
absl::string_view hlo_string = R"(
HloModule EmbeddedComputationParameters, is_scheduled=true
%EmbeddedComputationParameters_embedded (embedded_param0: f32[42], embedded_param1: f32[42]) -> (f32[42], f32[42]) {
%embedded_param0 = f32[42]{0} parameter(0)
%log = f32[42]{0} log(f32[42]{0} %embedded_param0)
%add = f32[42]{0} add(f32[42]{0} %log, f32[42]{0} %log)
%embedded_param1 = f32[42]{0} parameter(1)
ROOT %tuple = (f32[42]{0}, f32[42]{0}) tuple(f32[42]{0} %add, f32[42]{0} %embedded_param1)
}
ENTRY %EmbeddedComputationParameters (param0: f32[42], param1: f32[42]) -> (f32[42], f32[42]) {
%param0 = f32[42]{0} parameter(0)
%param1 = f32[42]{0} parameter(1)
ROOT %call = (f32[42]{0}, f32[42]{0}) call(f32[42]{0} %param0, f32[42]{0} %param1), to_apply=%EmbeddedComputationParameters_embedded
}
)";
HloModuleConfig hlo_config;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
auto liveness =
BufferLiveness::Run(
module.get(),
absl::make_unique<SequentialHloOrdering>(module->schedule()))
.ConsumeValueOrDie();
auto embedded_log = FindInstruction(module.get(), "log");
auto embedded_param0 = FindInstruction(module.get(), "embedded_param0");
auto embedded_param1 = FindInstruction(module.get(), "embedded_param1");
auto param0 = FindInstruction(module.get(), "param0");
auto param1 = FindInstruction(module.get(), "param1");
// Parameters should interfere with other instructions inside the computation.
EXPECT_TRUE(
InstructionsMayInterfere(*liveness, embedded_log, embedded_param1));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param0));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param1));
EXPECT_TRUE(
InstructionsMayInterfere(*liveness, embedded_param0, embedded_param1));
}
TEST_F(BufferLivenessTest, InterferenceWithOuterRoot) {
absl::string_view hlo_string = R"(
HloModule InterferenceWithOuterRoot, is_scheduled=true
Emmbedded (embedded_param: f32[42]) -> f32[42] {
embedded_param = f32[42]{0} parameter(0)
multiply = f32[42]{0} multiply(embedded_param, embedded_param)
ROOT log = f32[42]{0} log(multiply)
}
ENTRY InterferenceWithOuterRoot {
param = f32[4096,4096]{1,0} parameter(0)
ROOT add = f32[4096,4096]{1,0} add(param, param)
call = f32[42]{0} call(param), to_apply=Emmbedded
}
)";
HloModuleConfig hlo_config;
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
auto liveness =
BufferLiveness::Run(
module.get(),
absl::make_unique<SequentialHloOrdering>(module->schedule()))
.ConsumeValueOrDie();
auto multiply = FindInstruction(module.get(), "multiply");
auto add = FindInstruction(module.get(), "add");
EXPECT_TRUE(InstructionsMayInterfere(*liveness, multiply, add));
}
TEST_F(BufferLivenessTest, NonElementwiseOperand) {
// A chain of operations with two elementwise and one non-elementwise. The
// elementwise op should not interfere with its operand, while the
// non-elementwise op should interfere. Entry params always interfere.
//
// param --> exp -> negate -> reverse
//
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
auto reverse =
builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse));
// Negate is elementwise, so doesn't interfere with its operand.
// Reverse is non-elementwise, so does interfere with its operand.
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse));
}
TEST_F(BufferLivenessTest, OverlappedBuffers) {
// Verify simultaneously live buffers interfere (exp and negate).
//
// param --> negate -> add
// \---> exp -----/
//
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Negate and exp interfere with each other, but not with add.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
}
TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
// Identical to the test OverlappedBuffer but using a sequential ordering of
// HLO instructions.
//
// param --> negate -> add
// \---> exp -----/
//
// Sequential order:
// param, negate, exp, add
//
// Liveness is identical to the DependencyHloOrdering.
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {param, negate, exp, add});
auto liveness =
BufferLiveness::Run(module.get(),
absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Negate and exp interfere with each other, but not with add.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
}
TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
// Tests that when the root instruction is not the last instruction in the
// schedule, the live range of its buffers interfere with the buffers of the
// later instructions.
//
// Two sets of independent instructions are executed in the computation.
// param --> add (root)
// recv --> recv-done --> send --> send-done
//
// Sequential order:
// param, add (root), recv, recv-done, send, send-done
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build(add));
HloSchedule schedule(module.get());
schedule.set_sequence(computation,
{param, add, token, recv, recv_done, send, send_done});
TF_ASSERT_OK(schedule.Verify());
auto liveness =
BufferLiveness::Run(module.get(),
absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
// Check the root instruction (add) buffer interferes with the recv buffer.
EXPECT_TRUE(
liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}),
GetBuffer(*liveness, recv, /*index=*/{0})));
}
TEST_F(BufferLivenessTest, TupleLiveOut) {
// Verify MaybeLiveOut with nested tuples. Result of computation looks like:
//
// Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))})
//
// All values should be live out except Param.
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
auto inner_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({negate}));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
auto outer_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// All buffers should be live out except the param
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple));
}
// bitcast liveout.
TEST_F(BufferLivenessTest, EmbeddedComputation) {
// Test MaybeLiveOut and MayInterfere for embedded computation.
auto module = CreateNewVerifiedModule();
auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
auto embedded_param = embedded_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec_, "embedded_param"));
auto embedded_log = embedded_builder.AddInstruction(
HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
auto embedded_computation =
module->AddEmbeddedComputation(embedded_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto call = builder.AddInstruction(
HloInstruction::CreateCall(vec_, {param}, embedded_computation));
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Buffers in different computations should always interfere.
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call));
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param));
EXPECT_FALSE(
InstructionsMayInterfere(*liveness, embedded_param, embedded_log));
// The only buffers for which MaybeLiveOut == true are those live out
// of the entry computation. Buffers live out of embedded computations should
// return false for this method.
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log));
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call));
}
TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// Verify non top-level elements of a nested tuple constant are properly
// marked as liveout. Computation:
//
// GetTupleElement(0, TupleConstant({{0, 1}, {3}})
//
// Only the array buffers containing 0 and 1 are liveout of the
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
LiteralUtil::CreateR0<int64>(1)};
auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
Literal element1 = LiteralUtil::CreateR0<int64>(3);
auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Only the element buffers of the tuple constant which are pointed to by
// the GetTupleElement instruction should be liveout.
EXPECT_FALSE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{})));
EXPECT_TRUE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
EXPECT_TRUE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
EXPECT_TRUE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));
EXPECT_FALSE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
EXPECT_FALSE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
EXPECT_FALSE(liveness->MaybeLiveOut(
GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
}
TEST_F(BufferLivenessTest, IndependentTupleElements) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0,
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
"param0"));
  // Create independent computations for each tuple element.
// Tuple element0 computation:
// Add(GetTupleElement(tuple_param0, 0), const0)
auto tuple_element0_shape =
ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
auto tuple_element0 =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
// Tuple element1 computation:
// Add(GetTupleElement(tuple_param0, 1), const1)
auto tuple_element1_shape =
ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
auto tuple_element1 =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element1_shape, tuple_param0, 1));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
// Create output tuple.
auto tuple_root =
builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
auto module = CreateNewUnverifiedModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
// 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
// 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
// Tuple output element 'add0' does not depend on input 'tuple_element1'.
// Tuple output element 'add1' does not depend on input 'tuple_element0'.
  // Neither element pair interferes, because there is no other dependency
  // on the pair's tuple input element, and so liveness can compute that all
// users of the input tuple element execute before the associated output
// tuple element.
EXPECT_FALSE(
TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
EXPECT_FALSE(
TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
}
TEST_F(BufferLivenessTest, DependentTupleElements) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0,
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
"param0"));
  // Create dependent computations for each tuple element.
// Tuple element0 computation:
// Add(GetTupleElement(tuple_param0, 0), const0)
auto tuple_element0_shape =
ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
auto tuple_element0 =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
// Tuple element1 computation:
// Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
auto tuple_element1_shape =
ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
auto tuple_element1 =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element1_shape, tuple_param0, 1));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1));
// Create output tuple.
auto tuple_root =
builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
auto liveness =
BufferLiveness::Run(
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
// 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
// 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
  // The first tuple element pair's output 'add0' has no dependency on the
  // second tuple element pair's input 'tuple_element1'.
  // The second tuple element pair's output 'add1' has a dependency on the
  // first tuple element pair's input 'tuple_element0'.
// The first tuple element pair does interfere, because liveness cannot
// compute that all references to 'tuple_element0' are executed before 'add0'
  // (because of the dependency of 'add1' on 'tuple_element0').
EXPECT_TRUE(
TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
// The second tuple element pair does not interfere, because there is no
// other dependency on 'tuple_element1', and so liveness can compute that
// all users execute before 'add1'.
EXPECT_FALSE(
TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
}
class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
std::unique_ptr<VerifiedHloModule> BuildModule(
const bool update_uses_tuple_element1, const bool fuse_gte0) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
HloInstruction* slice = nullptr;
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
slice = builder.AddInstruction(
HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice));
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, {starts}));
// Create output tuple.
builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
auto* computation = module->entry_computation();
// Create fusion instruction based on number of tuple element 1 users.
if (update_uses_tuple_element1) {
computation->CreateFusionInstruction(
{dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
HloInstruction::FusionKind::kLoop);
} else {
computation->CreateFusionInstruction(
{dynamic_update_slice, starts, update, gte1},
HloInstruction::FusionKind::kLoop);
}
// Create fusion instruction for tuple element 0 (if requested).
if (fuse_gte0) {
computation->CreateFusionInstruction({gte0},
HloInstruction::FusionKind::kLoop);
}
return module;
}
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
bool Run(const bool update_uses_tuple_element1,
const bool fuse_gte0 = false) {
auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
// Run BufferLiveness on 'module'.
auto liveness = BufferLiveness::Run(
module.get(),
absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
    // Return whether buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
auto tuple_param0 = FindInstruction(module.get(), "param0");
auto tuple_root = module->entry_computation()->root_instruction();
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
const bool fuse_gte0 = false) {
auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
// Run BufferLiveness on 'module'.
auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
    // Return whether buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
auto tuple_param0 = FindInstruction(module.get(), "param0");
auto tuple_root = module->entry_computation()->root_instruction();
return hlo_ordering->MayInterfere(
dataflow->GetUniqueValueAt(tuple_param0, {1}),
dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
}
};
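Worth noting in the fixture above: the RunWithHloDataflowAnalysis helper already exercises the replacement path, checking interference with no BufferLiveness at all by combining HloDataflowAnalysis with an HloOrdering. The fragment below is lifted from that helper (same identifiers, not a standalone program) to highlight the successor API this commit standardizes on:

```cpp
// Interference query without BufferLiveness: dataflow values + an ordering.
auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
bool interferes = hlo_ordering->MayInterfere(
    dataflow->GetUniqueValueAt(tuple_param0, {1}),
    dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
```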
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
// do not overlap with the following computation:
//
// Param0
// / \
// GTE(0) Fusion -----------> FusionParam
// | | |
// | | GTE(1) Const Const
// | | \ | /
// | | DynamicUpdateSlice // fused root
// \ /
// Tuple // computation root
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
EXPECT_FALSE(
RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
// 'fusion1') do not overlap in the presence of another fusion instruction
// (which is a user of 'param0' at a different tuple index).
// BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
// (because Fusion0 only uses Param0 at index {0}).
//
// Param0
// / \
// FusionParam <----- Fusion0 Fusion1 ------> FusionParam
// | | | |
// GTE(0) | | GTE(1) Const Const
// | | \ | /
// \ / DynamicUpdateSlice
// Tuple
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
/*fuse_gte0=*/true));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
// do overlap because GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Slice at operand 0.
//
// Param0
// / \ Const
// / \ /
// GTE(0) Fusion -----------> FusionParam FusionParam
// | | | |
// | | GTE(1) /
// | | | \ /
// | | | Slice /
// | | | \ /
// | | | Add Const
// | | | | |
// | | DynamicUpdateSlice // fused root
// \ /
// Tuple // computation root
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
// Runs BufferLiveness on this computation.
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
bool Run(const bool tuple_element1_has_two_uses) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, gte0, gte1));
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, {starts}));
// Create output tuple.
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness = BufferLiveness::Run(
module.get(),
absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
    // Return whether buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
// the following computation (because DynamicUpdateSlice (at operand 0) is the
// unique user):
//
// Parameter0
// | |
// GTE(0) GTE(1) Const Const
// | \ | /
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
// GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Add at operand 1.
//
// Parameter0
// | |
// GTE(0) GTE(1)
// | / |
// | / |
// Add | Const Const
// | | | |
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
}
} // namespace
} // namespace xla