Merge branch 'master' into resolve-exhaustive-test-compile-issue

Commit: 20f510ceae

.bazelrc (+1)
@@ -40,6 +40,7 @@ build:mkl -c opt
 # This config option is used to enable MKL-DNN open source library only,
 # without depending on MKL binary version.
 build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_v1_only=true
 build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
 build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0
 
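The new `build_with_mkl_dnn_v1_only` define slots into the existing `mkl_open_source_only` config, so a `--config=mkl_open_source_only` build now pulls in MKL-DNN v1 without the binary MKL dependency. As a sanity check after building, one can ask TensorFlow itself whether MKL codepaths were compiled in; a minimal sketch, assuming the private `test_util` helper present around this release (not a stable public API):

```python
# Sketch: check whether an installed TensorFlow build has MKL-DNN enabled.
# IsMklEnabled() lives in a private module; treat this as illustrative only.
from tensorflow.python.framework import test_util

print("MKL enabled:", test_util.IsMklEnabled())
```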
RELEASE.md (+271)
@@ -1,3 +1,274 @@
+# Release 1.14.0
+
+## Major Features and Improvements
+
+* This is the first 1.x release containing the compat.v2 module. This module
+  is required to allow libraries to publish code which works in both 1.x and
+  2.x. After this release, no backwards incompatible changes are allowed in
+  the 2.0 Python API.
+* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
+  dispatches the best kernel implementation based on CPU vector architecture.
+  To disable them, build with --define=tensorflow_mkldnn_contraction_kernel=0.
+
+## Behavioral changes
+
+* Set default loss reduction as `AUTO` for improving reliability of loss
+  scaling with distribution strategy and custom training loops. `AUTO`
+  indicates that the reduction option will be determined by the usage context.
+  For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When used in
+  distribution strategy scope, outside of built-in training loops such as
+  `tf.keras` `compile` and `fit`, we expect the reduction value to be 'None'
+  or 'SUM'. Using other values will raise an error.
+* Wraps losses passed to the `compile` API (strings and v1 losses) which are
+  not instances of the v2 `Loss` class in a `LossWrapper` class, so all losses
+  will now use `SUM_OVER_BATCH_SIZE` reduction as default.
+* Disable `run_eagerly` and distribution strategy if there are symbolic
+  tensors added to the model using `add_metric` or `add_loss`.
+* tf.linspace(start, stop, num) now always uses "stop" as the last value (for
+  num > 1).
+* `ResourceVariable` and `Variable` no longer accept `constraint` in the
+  constructor, nor expose it as a @property.
+* The behavior of tf.gather is now correct when axis=None and batch_dims<0.
+* Only create a GCS directory object if the object does not already exist.
+* In `map_vectorization` optimization, reduce the degree of parallelism in
+  the vectorized map node.
+* Bug fix: loss and gradients should now more reliably be correctly scaled
+  w.r.t. the global batch size when using a tf.distribute.Strategy.
+* Updated cosine similarity loss - removed the negate sign from cosine
+  similarity.
+* DType is no longer convertible to an int. Use dtype.as_datatype_enum
+  instead of int(dtype) to get the same result.
+* Changed default for gradient accumulation for TPU embeddings to true.
+* Callbacks now log values in eager mode when a deferred build model is used.
+* Transitive dependencies on :pooling_ops were removed. Some users may need
+  to add explicit dependencies on :pooling_ops if they reference the
+  operators from that library.
+
+## Bug Fixes and Other Changes
+
+* Documentation
+* Deprecations and symbol renames.
+  * Remove unused StringViewVariantWrapper.
+  * Delete unused Fingerprint64Map op registration.
+  * SignatureDef util functions have been deprecated.
+  * Renamed tf.image functions to remove duplicate "image" where it is
+    redundant.
+  * tf.keras.experimental.export renamed to
+    tf.keras.experimental.export_saved_model.
+  * Standardize the LayerNormalization API by replacing the args `norm_axis`
+    and `params_axis` with `axis`.
+  * Tensor::UnsafeCopyFromInternal deprecated in favor of
+    Tensor::BitcastFrom.
+* Keras & Python API
+  * Add v2 module aliases for:
+    * tf.initializers => tf.keras.initializers
+    * tf.losses => tf.keras.losses & tf.metrics => tf.keras.metrics
+    * tf.optimizers => tf.keras.optimizers
+  * Add tf.keras.layers.AbstractRNNCell as the preferred implementation of
+    an RNN cell for TF v2. Users can use it to implement RNN cells with
+    custom behavior.
+  * Add a `clear_losses` API to be able to clear losses at the end of a
+    forward pass in a custom training loop in eager mode.
+  * Add support for passing a list of lists to the `metrics` param in Keras
+    `compile`.
+  * Added top-k to precision and recall to Keras metrics.
+  * Add public APIs for `cumsum` and `cumprod` Keras backend functions.
+  * Fix: model.add_loss(symbolic_tensor) should work in ambient eager.
+  * Add name argument to tf.string_split and tf.strings_split.
+  * Minor change to SavedModels exported from Keras using
+    tf.keras.experimental.export. (The SignatureDef key for evaluation mode
+    is now "eval" instead of "test".) This will be reverted back to "test"
+    in the near future.
+  * Updates binary cross-entropy logic in Keras when the input is
+    probabilities. Instead of converting probabilities to logits, we are
+    using the cross-entropy formula for probabilities.
+  * Raw TensorFlow functions can now be used in conjunction with the Keras
+    Functional API during model creation. This obviates the need for users
+    to create Lambda layers in most cases when using the Functional API.
+    Like Lambda layers, TensorFlow functions that result in Variable
+    creation or assign ops are not supported.
+  * Keras training and validation curves are shown on the same plot.
+  * Introduce a `dynamic` constructor argument in Layer and Model, which
+    should be set to True when using imperative control flow in the `call`
+    method.
+  * Removed dtype from the constructor of initializers and partition_info
+    from call.
+* New ops and improved op functionality
+  * Add OpKernels for some stateless maps.
+  * Add v2 APIs for AUCCurve and AUCSummationMethod
+    enums. #tf-metrics-convergence
+  * Add tf.math.nextafter op.
+  * Add CompositeTensor base class.
+  * Add tf.linalg.tridiagonal_solve op.
+  * Add OpKernel templates for common table operations.
+  * Added support for TFLite in TensorFlow 2.0.
+  * Adds summary trace API for collecting graph and profile information.
+  * Add batch_dims argument to tf.gather.
+  * Add support for `add_metric` in the graph function mode.
+  * Add C++ gradient for BatchMatMulV2.
+  * Added tf.random.binomial.
+  * Added gradient for SparseToDense op.
+  * Add legacy string flat hash map op kernels.
+  * Add a ragged size op and register it to the op dispatcher.
+  * Add broadcasting support to tf.matmul.
+  * Add ellipsis (...) support for tf.einsum().
+  * Added LinearOperator.adjoint and LinearOperator.H (alias).
+  * Added GPU implementation of tf.linalg.tridiagonal_solve.
+  * Added strings.byte_split.
+  * Add RaggedTensor.placeholder().
+  * Add a new "result_type" parameter to tf.strings.split.
+  * `add_update` can now be passed a zero-arg callable in order to support
+    turning off the update when setting `trainable=False` on a Layer of a
+    Model compiled with `run_eagerly=True`.
+  * Add variant wrapper for absl::string_view.
+  * Add expand_composites argument to all nest.* methods.
+  * Add pfor converter for Squeeze.
+  * Bug fix for the tf.tile gradient.
+  * Expose CriticalSection in core as tf.CriticalSection.
+  * Update Fingerprint64Map to use aliases.
+  * ResourceVariable support for gather_nd.
+  * ResourceVariable's gather op supports batch dimensions.
+  * Variadic reduce is supported on CPU.
+  * Extend tf.function with basic support for CompositeTensor arguments
+    (such as SparseTensor and RaggedTensor).
+  * Add templates and interfaces for creating lookup tables.
+  * The post-training quantization tool supports quantizing weights shared
+    by multiple operations. The models made with versions of this tool will
+    use INT8 types for weights and will only be executable by interpreters
+    from this version onwards.
+  * Malformed gif images could result in an out-of-bounds access in the
+    color palette of the frame. This has been fixed.
+  * image.resize now considers proper pixel centers and has new kernels
+    (incl. anti-aliasing).
+* Performance
+  * Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
+    dispatches the best kernel implementation based on CPU vector
+    architecture. To disable them, build with
+    --define=tensorflow_mkldnn_contraction_kernel=0.
+  * Support for multi-host ncclAllReduce in Distribution Strategy.
+  * Expose a flag that allows the number of threads to vary across Python
+    benchmarks.
+* TensorFlow 2.0 Development
+  * Add v2 sparse categorical crossentropy metric.
+  * Allow non-Tensors through v2 losses.
+  * Add UnifiedGRU as the new GRU implementation for TF 2.0. Change the
+    default recurrent activation function for GRU from 'hard_sigmoid' to
+    'sigmoid', and 'reset_after' to True in 2.0. Historically the recurrent
+    activation was 'hard_sigmoid' since it is faster than 'sigmoid'. With
+    the new unified backend between CPU and GPU modes, and since the CuDNN
+    kernel uses sigmoid, we changed the default for CPU mode to sigmoid as
+    well. With that, the default GRU will be compatible with both CPU and
+    GPU kernels. This will enable users with a GPU to use the CuDNN kernel
+    by default and get a 10x performance boost in training. Note that this
+    is a checkpoint-breaking change. To use a 1.x pre-trained checkpoint,
+    construct the layer with GRU(recurrent_activation='hard_sigmoid',
+    reset_after=False) to fall back to the 1.x behavior.
+  * TF 2.0 - Update metric names to always reflect what the user has given
+    in compile. This affects the following cases: 1. when the name is given
+    as 'accuracy'/'crossentropy'; 2. when an aliased function name is used,
+    e.g. 'mse'; 3. removing the `weighted` prefix from weighted metric
+    names.
+  * Begin adding a Go wrapper for the C Eager API.
+  * image.resize in 2.0 now supports gradients for the new resize kernels.
+  * Removed tf.string_split from the v2 API.
+  * Expose tf.contrib.proto.* ops in tf.io (they will exist in TF2).
+  * Updates the TFLiteConverter API in 2.0. Changes from_concrete_function
+    to from_concrete_functions.
+  * Enable tf.distribute.experimental.MultiWorkerMirroredStrategy working
+    in eager mode.
+  * Support both binary and -1/1 label input in v2 hinge and squared hinge
+    losses.
+* TensorFlow Lite
+  * Adds support for tflite_convert in 2.0.
+  * Remove lite.OpHint, lite.experimental, and lite.constant from the 2.0
+    API.
+* tf.contrib
+  * Added a Neural Turing Machine implementation as described in
+    https://arxiv.org/abs/1807.08518.
+  * Remove tf.contrib.timeseries dependency on TF distributions.
+* tf.data
+  * Add num_parallel_reads and support for passing in a Dataset containing
+    filenames into TextLineDataset and FixedLengthRecordDataset.
+  * Going forward we operate in TF 2.0; this change is part of the effort
+    to slowly convert XYZDataset to the DatasetV2 type, which is the
+    official version going to be used in TF 2.0. It is motivated by a
+    compatibility issue found when moving contrib.bigtable to
+    tensorflow_io: _BigtableXYZDataset (of type DatasetV2) does not
+    implement the _as_variant_tensor() of DatasetV1. Converting into
+    DatasetV2 removes the overhead of maintaining V1 while we are moving
+    into TF 2.0.
+  * Add dataset ops to the graph (or create kernels in eager execution)
+    during the Python Dataset object creation instead of doing it at
+    Iterator creation time.
+  * Add support for TensorArrays to tf.data Dataset.
+  * Switching tf.data functions to use `defun`, providing an escape hatch
+    to continue using the legacy `Defun`.
+* Toolchains
+  * CUDNN_INSTALL_PATH, TENSORRT_INSTALL_PATH, NCCL_INSTALL_PATH, and
+    NCCL_HDR_PATH are deprecated. Use TF_CUDA_PATHS instead, which supports
+    a comma-separated list of base paths that are searched to find CUDA
+    libraries and headers.
+  * TF code now resides in `tensorflow_core` and `tensorflow` is just a
+    virtual pip package. No code changes are needed for projects using
+    TensorFlow; the change is transparent.
+* XLA
+  * XLA HLO graphs can now be inspected with the interactive_graphviz tool.
+* Estimator
+  * Use tf.compat.v1.estimator.inputs instead of tf.estimator.inputs.
+  * Replace contrib references with tf.estimator.experimental.* for APIs in
+    early_stopping.py.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+1e100, 4d55397500, a6802739, abenmao, Adam Weiss, Ag Ramesh, Alan Du, Albin Joy,
+Alex, Aman Patel, Amit, Amit Kumar Jaiswal, Amit Srivastava, Andreas Eberle,
+Andy Craze, Anthony Platanios, Armen Poghosov, armenpoghosov, arp95, Arpit Shah,
+Ashwin Ramaswami, Aurelien Geron, AuréLien Geron, aweers, awesomealex1, Ayush
+Agrawal, Ben Barsdell, Bharat Raghunathan, Bhavani Subramanian, blairhan,
+BléNesi Attila, Brandon Carter, candy.dc, Chao Liu, chenchc, chie8842, Christian
+Hansen, Christian Sigg, Clayne Robison, crafet, csukuangfj, ctiijima, Dan
+Jarvis, Dan Lazewatsky, Daniel Ingram, Daniel Salvadori, Dave Airlie, David
+Norman, Dayananda V, Dayananda-V, delock, Denis Khalikov, Deven Desai, Dheeraj
+Rajaram Reddy, dmitrievanthony, Donovan Ong, Drew Szurko, Duncan Riach, Dustin
+Neighly, Edward Forgacs, EFanZh, Fei Hu, Felix Lemke, Filip Matzner, fo40225,
+frreiss, Gautam, gehring, Geoffrey Irving, Grzegorz George Pawelczak, Grzegorz
+Pawelczak, Gyoung-Yoon Ryoo, HanGuo97, Hanton Yang, Hari Shankar, hehongliang,
+Heungsub Lee, Hoeseong Kim, I-Hong Jhuo, Ilango R, Innovimax, Irene Dea, Jacky
+Ko, Jakub Lipinski, Jason Zaman, jcf94, Jeffrey Poznanovic, Jens Elofsson,
+Jeroen BéDorf, Jia Qingtong, Jiankang, Joe Q, Joe Quadrino, Joeran Beel, Jonas
+Rauber, Jonathan, Jonathan Kyl, Joppe Geluykens, Joseph Friedman, jtressle, jwu,
+K Yasaswi Sri Chandra Gandhi, K. Hodges, Kaixi Hou, Karl Lessard, Karl
+Weinmeister, Karthik Muthuraman, Kashif Rasul, KDR, Keno Fischer, Kevin Mader,
+kjopek, Koan-Sin Tan, kouml, ktaebum, Lakshay Tokas, Laurent Le Brun, Letian
+Kang, Li, Guizi, Loo Rong Jie, Lucas Hendren, Lukas Geiger, Luke Han, luxupu,
+Ma, Guokai, Mahmoud Abuzaina, Mandar Deshpande, manhyuk, Marco Gaido, Marek
+Drozdowski, Mark Collier, Mark Ryan, mars20, Mateusz Chudyk, Matt Conley,
+MattConley, mbhuiyan, mdfaijul, Melissa Grueter, Michael KäUfl, MickaëL
+Schoentgen, Miguel Morin, Mihail Salnikov, Mike Arpaia, Mike Holcomb, monklof,
+Moses Marin, Mshr-H, nammbash, Natalia Gimelshein, Nayana-Ibm, neargye, Neeraj
+Pradhan, Nehal J Wani, Nick, Niels Ole Salscheider, Niranjan Hasabnis, nlewycky,
+Nuka-137, Nutti, olicht, P Sudeepam, Palmer Lao, Pan Daoxin, Pariksheet Pinjari,
+Pavel Samolysov, PENGWA, Pooya Davoodi, R S Nikhil Krishna, Rohit Gupta, Roman
+Soldatow, rthadur, Ruizhe, Ryan Jiang, Samantha Andow, Sami Kama, Sana-Damani,
+Saurabh Deoras, sdamani, seanshpark, Sebastien Iooss, Serv-Inc, Shahzad Lone,
+Shashank Gupta, Shashi, shashvat, shashvatshahi1998, Siju, Siju Samuel,
+Snease-Abq, Spencer Schaber, sremedios, srinivasan.narayanamoorthy, Steve Lang,
+Steve Nesae, Sumesh Udayakumaran, Supriya Rao, Taylor Jakobson, Taylor Thornton,
+Ted Chang, ThisIsPIRI, Thomas Deegan, Thomas Hagebols, tianyapiaozi, Tim Zaman,
+tomguluson92, Tongxuan Liu, TungJerry, v1incent, Vagif, vcarpani, Vikram Tiwari,
+Vishwak Srinivasan, Vitor-Alves, wangsiyu, wateryzephyr, WeberXie, WeijieSun,
+Wen-Heng (Jack) Chung, wenxizhu, Will Battel, William D. Irons, wyzhao, Xin,
+Yasuhiro Matsumoto, ymodak, Yong Tang, Younes Khoudli, Yuan Lin, Yves-Noel
+Weweler, Zantares, zjjott, 卜居, 王振华 (Wang Zhenhua), 黄鑫
+
+# Release 1.12.3
+
+## Bug Fixes and Other Changes
+
+* Updates `png_archive` dependency to 1.6.37 to not be affected by
+  CVE-2019-7317, CVE-2018-13785, and CVE-2018-14048.
+* Updates `sqlite` dependency to 3.28.0 to not be affected by CVE-2018-20506,
+  CVE-2018-20346, and CVE-2018-20505.
+
 # Release 1.12.2
 
 ## Bug Fixes and Other Changes
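Two of the behavioral changes above are easy to trip over in user code: `tf.linspace(start, stop, num)` now always includes `stop` as the final value, and the new default GRU configuration breaks 1.x checkpoints. A small sketch of both, using the fallback constructor arguments quoted in the notes:

```python
import tensorflow as tf

# linspace now always ends exactly at `stop` (for num > 1).
values = tf.linspace(0.0, 10.0, 5)  # -> [0.0, 2.5, 5.0, 7.5, 10.0]

# The new GRU defaults (recurrent_activation='sigmoid', reset_after=True)
# are CuDNN-compatible but checkpoint-breaking. To load a 1.x checkpoint,
# reconstruct the layer with the old defaults, as the notes suggest:
legacy_gru = tf.keras.layers.GRU(
    units=64, recurrent_activation='hard_sigmoid', reset_after=False)
```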
@@ -41,22 +41,30 @@ from tensorflow.python.tools import module_util as _module_util
 
 # API IMPORTS PLACEHOLDER
 
+# WRAPPER_PLACEHOLDER
+
 # Make sure directory containing top level submodules is in
 # the __path__ so that "from tensorflow.foo import bar" works.
 # We're using bitwise, but there's nothing special about that.
-_API_MODULE = bitwise  # pylint: disable=undefined-variable
-_current_module = _sys.modules[__name__]
+_API_MODULE = _sys.modules[__name__].bitwise
 _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
+_current_module = _sys.modules[__name__]
 if not hasattr(_current_module, '__path__'):
   __path__ = [_tf_api_dir]
 elif _tf_api_dir not in __path__:
   __path__.append(_tf_api_dir)
 
 # Hook external TensorFlow modules.
+
+# Import compat before trying to import summary from tensorboard, so that
+# reexport_tf_summary can get compat from sys.modules
+_current_module.compat.v2.compat.v1 = _current_module.compat.v1
 try:
   from tensorboard.summary._tf import summary
   _current_module.__path__ = (
       [_module_util.get_parent_dir(summary)] + _current_module.__path__)
+  setattr(_current_module, "summary", summary)
 except ImportError:
   _logging.warning(
       "Limited tf.summary API due to missing TensorBoard installation.")

@@ -65,6 +73,7 @@ try:
   from tensorflow_estimator.python.estimator.api._v2 import estimator
   _current_module.__path__ = (
       [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+  setattr(_current_module, "estimator", estimator)
 except ImportError:
   pass
 
@@ -72,6 +81,7 @@ try:
   from tensorflow.python.keras.api._v2 import keras
   _current_module.__path__ = (
       [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+  setattr(_current_module, "keras", keras)
 except ImportError:
   pass
 
@@ -122,25 +132,17 @@ if _running_from_pip_package():
 # pylint: disable=undefined-variable
 try:
   del python
-  if '__all__' in vars():
-    vars()['__all__'].remove('python')
-  del core
-  if '__all__' in vars():
-    vars()['__all__'].remove('core')
 except NameError:
-  # Don't fail if these modules are not available.
-  # For e.g. this file will be originally placed under tensorflow/_api/v1 which
-  # does not have 'python', 'core' directories. Then, it will be copied
-  # to tensorflow/ which does have these two directories.
   pass
-# Similarly for compiler. Do it separately to make sure we do this even if the
-# others don't exist.
+try:
+  del core
+except NameError:
+  pass
 try:
   del compiler
-  if '__all__' in vars():
-    vars()['__all__'].remove('compiler')
 except NameError:
   pass
+# pylint: enable=undefined-variable
 
 # Add module aliases
 if hasattr(_current_module, 'keras'):

@@ -148,6 +150,8 @@ if hasattr(_current_module, 'keras'):
   metrics = keras.metrics
   optimizers = keras.optimizers
   initializers = keras.initializers
-  compat.v2.compat.v1 = compat.v1
+  setattr(_current_module, "losses", losses)
+  setattr(_current_module, "metrics", metrics)
+  setattr(_current_module, "optimizers", optimizers)
+  setattr(_current_module, "initializers", initializers)
 # pylint: enable=undefined-variable
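The repeated `setattr(_current_module, ...)` calls in these hunks matter because the `tensorflow` object that ends up in `sys.modules` may not be the plain module this template defines (the new `# WRAPPER_PLACEHOLDER` hints at a wrapper object). A bare global assignment only updates the original module's namespace; `setattr` on whatever `sys.modules[__name__]` currently holds updates the object that attribute lookups actually go through. A minimal standalone sketch of the difference, with hypothetical names:

```python
import sys
import types

# Hypothetical stand-in for the wrapper that WRAPPER_PLACEHOLDER installs.
class _Wrapper(types.ModuleType):
    pass

real_module = types.ModuleType("demo")   # the module the file itself defines
sys.modules["demo"] = _Wrapper("demo")   # what `import demo` will return

# A bare `answer = 42` executed inside demo's source would only touch
# real_module's namespace. Going through sys.modules hits the object that
# importers actually see:
_current = sys.modules["demo"]
setattr(_current, "answer", 42)

import demo
print(demo.answer)  # 42
```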
@@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util
 
 # API IMPORTS PLACEHOLDER
 
+# WRAPPER_PLACEHOLDER
+
 # Make sure directory containing top level submodules is in
 # the __path__ so that "from tensorflow.foo import bar" works.
 # We're using bitwise, but there's nothing special about that.
-_API_MODULE = bitwise  # pylint: disable=undefined-variable
+_API_MODULE = _sys.modules[__name__].bitwise  # pylint: disable=undefined-variable
 _current_module = _sys.modules[__name__]
 _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
 if not hasattr(_current_module, '__path__'):

@@ -46,6 +48,7 @@ try:
   from tensorflow_estimator.python.estimator.api._v1 import estimator
   _current_module.__path__ = (
       [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+  setattr(_current_module, "estimator", estimator)
 except ImportError:
   pass
 
@@ -53,6 +56,7 @@ try:
   from tensorflow.python.keras.api._v1 import keras
   _current_module.__path__ = (
       [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+  setattr(_current_module, "keras", keras)
 except ImportError:
   pass
 
@@ -77,9 +81,8 @@ if '__all__' in vars():
 
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 # The 'app' module will be imported as part of the placeholder section above.
-app.flags = flags  # pylint: disable=undefined-variable
-if '__all__' in vars():
-  vars()['__all__'].append('flags')
+_current_module.app.flags = flags  # pylint: disable=undefined-variable
+setattr(_current_module, "flags", flags)
 
 # Load all plugin libraries from site-packages/tensorflow-plugins if we are
 # running under pip.

@@ -122,25 +125,16 @@ if _running_from_pip_package():
 # pylint: disable=undefined-variable
 try:
   del python
-  if '__all__' in vars():
-    vars()['__all__'].remove('python')
-  del core
-  if '__all__' in vars():
-    vars()['__all__'].remove('core')
 except NameError:
-  # Don't fail if these modules are not available.
-  # For e.g. this file will be originally placed under tensorflow/_api/v1 which
-  # does not have 'python', 'core' directories. Then, it will be copied
-  # to tensorflow/ which does have these two directories.
   pass
-# Similarly for compiler. Do it separately to make sure we do this even if the
-# others don't exist.
+try:
+  del core
+except NameError:
+  pass
 try:
   del compiler
-  if '__all__' in vars():
-    vars()['__all__'].remove('compiler')
 except NameError:
   pass
 
-compat.v2.compat.v1 = compat.v1
+_current_module.compat.v2.compat.v1 = _current_module.compat.v1
 # pylint: enable=undefined-variable
@@ -64,6 +64,7 @@ tf_cuda_library(
     }) + [
         "@com_google_absl//absl/memory",
         "//tensorflow/core/common_runtime/eager:eager_operation",
+        "//tensorflow/core/distributed_runtime/eager:remote_mgr",
        "//tensorflow/core/distributed_runtime/eager:eager_client",
         "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
         "//tensorflow/core/distributed_runtime/rpc:grpc_channel",

@@ -54,6 +54,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #endif  // !IS_MOBILE_PLATFORM
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/rendezvous.h"

@@ -260,12 +261,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
 
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
+  auto remote_mgr =
+      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true);
 
   return ctx->context->InitializeRemoteMaster(
       std::move(server), grpc_server->worker_env(), worker_session,
       std::move(remote_eager_workers), std::move(remote_device_mgr),
       remote_workers, context_id, r, device_mgr, keep_alive_secs,
-      worker_session->cluster_flr.get());
+      worker_session->cluster_flr.get(), std::move(remote_mgr));
 #undef LOG_AND_RETURN_IF_ERROR
 }
 #endif  // !IS_MOBILE_PLATFORM
@@ -890,9 +890,19 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   // Stop the tape from recording
   pop_backward_tape.release()();
 
+  if (grad.size() != in_grads.size()) {
+    return tensorflow::errors::Internal("Wrong number of gradients returned.");
+  }
+
   std::vector<int64> targets;
+  std::vector<Gradient*> used_in_grads;
+  // We may end up with slightly fewer elements than we reserve, but grad.size()
+  // should be a reasonably tight upper bound.
+  targets.reserve(grad.size());
+  used_in_grads.reserve(grad.size());
   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
-  for (Gradient* grad_tensor : grad) {
+  for (int grad_index = 0; grad_index < grad.size(); ++grad_index) {
+    Gradient* grad_tensor = grad[grad_index];
     if (grad_tensor != nullptr) {
       int64 tensor_id = vspace_.TensorId(grad_tensor);
       targets.push_back(tensor_id);

@@ -900,23 +910,18 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
         sources_that_are_targets.emplace(
             tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
       }
+      Gradient* in_grad = in_grads[grad_index];
+      if (in_grad != nullptr) {
+        // ComputeGradient steals a reference
+        vspace_.MarkAsResult(in_grad);
+      }
+      used_in_grads.push_back(in_grad);
     }
   }
-  if (targets.size() > in_grads.size()) {
-    return tensorflow::errors::Internal("Too many gradients returned.");
-  }
-
-  for (int target_index = 0; target_index < targets.size(); ++target_index) {
-    Gradient* in_grad = in_grads[target_index];
-    Gradient* grad_tensor = grad[target_index];
-    if (grad_tensor != nullptr && in_grad != nullptr) {
-      // ComputeGradient steals a reference
-      vspace_.MarkAsResult(in_grad);
-    }
-  }
 
   return tape->ComputeGradient(vspace_, targets, sources,
-                               sources_that_are_targets, in_grads, out_grads);
+                               sources_that_are_targets, used_in_grads,
+                               out_grads);
 }
 
 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
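The tape change above replaces a second positional loop with a single pass that keeps each incoming gradient aligned with its (non-null) forward gradient, collecting only the entries that will actually be handed to `ComputeGradient`. A Python rendering of that control flow, with hypothetical stand-ins for the C++ helpers:

```python
def collect_targets_and_grads(grad, in_grads, tensor_id, mark_as_result):
    # Mirrors the revised C++ loop: fail fast on a size mismatch, then
    # keep targets and used_in_grads index-aligned.
    if len(grad) != len(in_grads):
        raise RuntimeError("Wrong number of gradients returned.")
    targets, used_in_grads = [], []
    for grad_tensor, in_grad in zip(grad, in_grads):
        if grad_tensor is None:
            continue  # no forward gradient for this output
        targets.append(tensor_id(grad_tensor))
        if in_grad is not None:
            mark_as_result(in_grad)  # ComputeGradient steals a reference
        used_in_grads.append(in_grad)
    return targets, used_in_grads
```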
@@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util
 
 # API IMPORTS PLACEHOLDER
 
+# WRAPPER_PLACEHOLDER
+
 # Hook external TensorFlow modules.
 _current_module = _sys.modules[__name__]
 try:
   from tensorboard.summary._tf import summary
   _current_module.__path__ = (
       [_module_util.get_parent_dir(summary)] + _current_module.__path__)
+  # Make sure we get the correct summary module with lazy loading
+  setattr(_current_module, "summary", summary)
 except ImportError:
   _logging.warning(
       "Limited tf.compat.v2.summary API due to missing TensorBoard "

@@ -43,6 +47,7 @@ try:
   from tensorflow_estimator.python.estimator.api._v2 import estimator
   _current_module.__path__ = (
       [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+  setattr(_current_module, "estimator", estimator)
 except ImportError:
   pass
 
@@ -50,6 +55,7 @@ try:
   from tensorflow.python.keras.api._v2 import keras
   _current_module.__path__ = (
       [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+  setattr(_current_module, "keras", keras)
 except ImportError:
   pass
 
@@ -61,11 +67,15 @@ except ImportError:
 #
 # This make this one symbol available directly.
 from tensorflow.python.compat.v2_compat import enable_v2_behavior  # pylint: disable=g-import-not-at-top
+setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
 
 # Add module aliases
-_current_module = _sys.modules[__name__]
 if hasattr(_current_module, 'keras'):
   losses = keras.losses
   metrics = keras.metrics
   optimizers = keras.optimizers
   initializers = keras.initializers
+  setattr(_current_module, "losses", losses)
+  setattr(_current_module, "metrics", metrics)
+  setattr(_current_module, "optimizers", optimizers)
+  setattr(_current_module, "initializers", initializers)

@@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util
 
 # API IMPORTS PLACEHOLDER
 
+# WRAPPER_PLACEHOLDER
+
 # Hook external TensorFlow modules.
 _current_module = _sys.modules[__name__]
 try:
   from tensorflow_estimator.python.estimator.api._v1 import estimator
   _current_module.__path__ = (
       [_module_util.get_parent_dir(estimator)] + _current_module.__path__)
+  setattr(_current_module, "estimator", estimator)
 except ImportError:
   pass
 
@@ -40,9 +43,11 @@ try:
   from tensorflow.python.keras.api._v1 import keras
   _current_module.__path__ = (
       [_module_util.get_parent_dir(keras)] + _current_module.__path__)
+  setattr(_current_module, "keras", keras)
 except ImportError:
   pass
 
 
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
-app.flags = flags  # pylint: disable=undefined-variable
+_current_module.app.flags = flags  # pylint: disable=undefined-variable
+setattr(_current_module, "flags", flags)
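Among the hunks above, the compat.v2 template now re-exports `enable_v2_behavior`. For callers, this is the standard switch for opting a 1.x binary into 2.x semantics, and it must run before any graphs or sessions are created; a short usage sketch:

```python
import tensorflow.compat.v2 as tf

# Must be called at program startup, before any graph or session exists.
tf.enable_v2_behavior()
```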
@@ -22,6 +22,7 @@ load(
     "tf_cc_test",
     "tf_copts",
 )
+load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
 
 def tf_library(
         name,

@@ -180,13 +181,7 @@ def tf_library(
     # `find` on such an object.
     need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
 
-    # Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel
-    # build --cpu=haswell). We put it at the beginning of the flags list so
-    # that tfcompile_flags can override if if desired.
-    flags = select({
-        "//tools/target_cpu:haswell": "--target_cpu=haswell ",
-        "//conditions:default": "",
-    }) + flags
+    flags = tfcompile_extra_flags() + flags
 
     if enable_xla_hlo_profiling:
         profiling_flag = "--xla_hlo_profile"
@@ -1087,8 +1087,6 @@ Status Encapsulator::MakePrunedGraphCopyAndInline(
         FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody));
 
     InlineFunctionBodyOptions inline_opts;
-    inline_opts.override_device = false;
-
     TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node,
                                           fbody.get(), inline_opts));
   }
@@ -183,6 +183,9 @@ class XlaAssignVariableOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"),       \
       data::AnonymousIteratorHandleOp);                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"),            \
+      data::DeleteIteratorOp);                                                \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),             \
                           data::IteratorGetNextOp);                           \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),   \
@@ -462,9 +462,16 @@ cc_library(
     alwayslink = 1,
 )
 
+filegroup(
+    name = "tf_tfl_translate_main",
+    srcs = [
+        "tf_tfl_translate.cc",
+    ],
+)
+
 tf_cc_binary(
     name = "tf_tfl_translate",
-    srcs = ["tf_tfl_translate.cc"],
+    srcs = [":tf_tfl_translate_main"],
     deps = [
         ":flatbuffer_translate_lib",
         ":tensorflow_lite",
@@ -26,11 +26,11 @@ namespace tflite {
 // Error reporter that reports errors via the module's emitError.
 class EmitErrorReporter : public ErrorReporter {
  public:
-  explicit EmitErrorReporter(mlir::Module module) : module_(module) {}
+  explicit EmitErrorReporter(mlir::ModuleOp module) : module_(module) {}
   int Report(const char* format, va_list args) override;
 
  private:
-  mlir::Module module_;
+  mlir::ModuleOp module_;
 };
 
 }  // namespace tflite
@@ -42,6 +42,13 @@ static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
       .Case("VALID", tflite::Padding_VALID);
 }
 
+static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter(
+    llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
+  return llvm::StringSwitch<tflite::MirrorPadMode>(str)
+      .Case("REFLECT", tflite::MirrorPadMode_REFLECT)
+      .Case("SYMMETRIC", tflite::MirrorPadMode_SYMMETRIC);
+}
+
 static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter(
     mlir::Type type, flatbuffers::FlatBufferBuilder* builder) {
   switch (type.getKind()) {
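The new `MirrorPadMode` plumbing distinguishes the two mirror-padding variants TensorFlow already exposes through `tf.pad`: `REFLECT` mirrors without repeating the edge value, while `SYMMETRIC` repeats it. For reference:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3]])
# REFLECT -> [[3, 2, 1, 2, 3, 2, 1]]: edge values are not duplicated.
reflected = tf.pad(x, paddings=[[0, 0], [2, 2]], mode="REFLECT")
# SYMMETRIC -> [[2, 1, 1, 2, 3, 3, 2]]: edge values are duplicated.
symmetric = tf.pad(x, paddings=[[0, 0], [2, 2]], mode="SYMMETRIC")
```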
@@ -77,9 +77,9 @@ using llvm::Twine;
 using mlir::Block;
 using mlir::Dialect;
 using mlir::ElementsAttr;
-using mlir::Function;
+using mlir::FuncOp;
 using mlir::MLIRContext;
-using mlir::Module;
+using mlir::ModuleOp;
 using mlir::NoneType;
 using mlir::openOutputFile;
 using mlir::Operation;

@@ -167,6 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
       return tflite::TensorType_STRING;
     case mlir::TF::TensorFlowTypes::COMPLEX64:
       return tflite::TensorType_COMPLEX64;
+    case mlir::TF::TensorFlowTypes::UINT8:
+      return tflite::TensorType_UINT8;
     case mlir::StandardTypes::Integer: {
       const auto& itype = type.cast<mlir::IntegerType>();
       switch (itype.getWidth()) {

@@ -243,18 +245,18 @@ static bool HasValidTFLiteType(Value* value, T& error_handler) {
 // TODO(hinsu): Now that translation is done by making a single pass over the
 // MLIR module, consider inlining these validation checks at the place where
 // these invariants are assumed instead of checking upfront.
-static bool IsValidTFLiteMlirModule(Module module) {
+static bool IsValidTFLiteMlirModule(ModuleOp module) {
   MLIRContext* context = module.getContext();
 
   // Verify that module has a function named main.
-  Function main_fn = module.getNamedFunction("main");
+  FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
   if (!main_fn) {
     return emitError(UnknownLoc::get(context),
                      "should have a function named 'main'"),
            false;
   }
 
-  for (auto fn : module.getOps<Function>()) {
+  for (auto fn : module.getOps<FuncOp>()) {
     if (fn.getBlocks().size() != 1) {
       return fn.emitError("should have exactly one basic block"), false;
     }
@@ -323,14 +325,14 @@ class Translator {
   // Translates the given MLIR module into TFLite FlatBuffer format and returns
   // the serialized output. Returns llvm::None on unsupported, invalid inputs or
   // internal error.
-  static Optional<std::string> Translate(Module module,
+  static Optional<std::string> Translate(ModuleOp module,
                                          bool emit_builtin_tflite_ops,
                                          bool emit_select_tf_ops,
                                          bool emit_custom_ops);
 
  private:
   enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
-  explicit Translator(Module module, bool emit_builtin_tflite_ops,
+  explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
                       bool emit_select_tf_ops, bool emit_custom_ops)
       : module_(module), builder_(kInitialBufferSize) {
     // The first buffer must be empty according to the schema definition.

@@ -391,11 +393,11 @@ class Translator {
       Operation* inst, const std::vector<int32_t>& operands,
       const std::vector<int32_t>& results);
 
-  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(Function fn);
+  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
 
   // Uses the tf.entry_function attribute (if set) to initialize the op to name
   // mapping.
-  void InitializeNamesFromAttribute(Function fn);
+  void InitializeNamesFromAttribute(FuncOp fn);
 
   // Returns a unique name for `op`.
   std::string UniqueName(mlir::Operation* op);

@@ -403,7 +405,7 @@ class Translator {
   // Returns a unique name starting with a given prefix.
   std::string UniqueName(llvm::StringRef prefix);
 
-  Module module_;
+  ModuleOp module_;
 
   flatbuffers::FlatBufferBuilder builder_;
   BufferOffset<tflite::Buffer> empty_buffer_;

@@ -449,10 +451,16 @@ std::string Translator::GetName(Operation* inst) {
 }
 
 std::string Translator::UniqueName(llvm::StringRef prefix) {
+  // Keep incrementing the counter until we find a unique name.
   std::string name = prefix;
-  auto& val = name_to_count_[name];
-  if (val) name = (prefix + llvm::Twine(val)).str();
-  ++val;
+  int64_t& prefix_count = name_to_count_[name];
+  int64_t val = prefix_count;
+  while (val != 0) {
+    name = (prefix + llvm::Twine(prefix_count)).str();
+    ++prefix_count;
+    val = name_to_count_[name];
+  }
+  name_to_count_[name] = 1;
   return name;
 }
 
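`UniqueName` previously appended the per-prefix counter once and could still collide with a name that was already registered (for example an entry-function input name recorded by `InitializeNamesFromAttribute`). The rewrite loops until the candidate is genuinely unused. The same logic in Python, as a hedged sketch rather than the exact TF helper:

```python
def unique_name(prefix, name_to_count):
    # Bump the prefix's counter until the generated candidate has never
    # been handed out, then record the final name as taken.
    name = prefix
    if name_to_count.get(name, 0) != 0:
        while True:
            name = prefix + str(name_to_count.get(prefix, 0))
            name_to_count[prefix] = name_to_count.get(prefix, 0) + 1
            if name_to_count.get(name, 0) == 0:
                break
    name_to_count[name] = 1
    return name

counts = {}
print(unique_name("conv", counts))  # conv
print(unique_name("conv", counts))  # conv1 (the next free suffix)
```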
@@ -781,7 +789,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
          llvm::None;
 }
 
-void Translator::InitializeNamesFromAttribute(Function fn) {
+void Translator::InitializeNamesFromAttribute(FuncOp fn) {
   auto dict_attr = fn.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
   if (!dict_attr) return;
 

@@ -794,8 +802,10 @@ void Translator::InitializeNamesFromAttribute(Function fn) {
       fn.emitWarning() << "invalid entry function specification";
       return;
     }
-    for (auto it : llvm::enumerate(fn.getArguments()))
+    for (auto it : llvm::enumerate(fn.getArguments())) {
       op_to_name_[*it.value()->user_begin()] = input_names[it.index()];
+      ++name_to_count_[input_names[it.index()].str()];
+    }
   }
 
   if (auto str = dict_attr.get("outputs").dyn_cast<mlir::StringAttr>()) {

@@ -813,18 +823,19 @@ void Translator::InitializeNamesFromAttribute(Function fn) {
       // ensure the name that will be assigned to the buffer is the same, or
       // insert an op so that we can have a buffer named such. This cannot
       // currently happen due to pseudo_input nodes.
-      if (auto op = it.value()->getDefiningOp())
+      if (auto op = it.value()->getDefiningOp()) {
         op_to_name_[op] = output_names[it.index()];
-      else
+        name_to_count_[output_names[it.index()].str()] = 1;
+      } else {
         fn.emitWarning() << "output is not due to an op and '"
                          << output_names[it.index()]
                          << "' may not be a named output";
+      }
     }
   }
 }
 
-Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
-    Function fn) {
+Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   InitializeNamesFromAttribute(fn);
   std::vector<BufferOffset<tflite::Tensor>> tensors;
   llvm::DenseMap<Value*, int> tensor_index_map;
@@ -927,7 +938,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
       /*name=*/builder_.CreateString(fn.getName().str()));
 }
 
-Optional<std::string> Translator::Translate(Module module,
+Optional<std::string> Translator::Translate(ModuleOp module,
                                             bool emit_builtin_tflite_ops,
                                             bool emit_select_tf_ops,
                                             bool emit_custom_ops) {

@@ -941,14 +952,14 @@ Optional<std::string> Translator::TranslateInternal() {
   // Create a list of functions in the module with main function being the
   // first function in the list. This is required as the first subgraph in the
   // model is entry point for the model.
-  std::vector<Function> functions;
+  std::vector<FuncOp> functions;
   functions.reserve(std::distance(module_.begin(), module_.end()));
 
   int subgraph_idx = 0;
-  Function main_fn = module_.getNamedFunction("main");
+  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
   subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
   functions.push_back(main_fn);
-  for (auto fn : module_.getOps<Function>()) {
+  for (auto fn : module_.getOps<FuncOp>()) {
     if (fn == main_fn) continue;
 
     subgraph_index_map_[fn.getName().str()] = subgraph_idx++;

@@ -992,7 +1003,7 @@ Optional<std::string> Translator::TranslateInternal() {
 // * Ops with variable tensors
 //
 bool tflite::MlirToFlatBufferTranslateFunction(
-    Module module, std::string* serialized_flatbuffer,
+    ModuleOp module, std::string* serialized_flatbuffer,
     bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
     bool emit_custom_ops) {
   auto maybe_translated = Translator::Translate(

@@ -1003,7 +1014,7 @@ bool tflite::MlirToFlatBufferTranslateFunction(
 }
 
 static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction(
-    Module module, llvm::StringRef filename) {
+    ModuleOp module, llvm::StringRef filename) {
   std::string serialized_flatbuffer;
   if (tflite::MlirToFlatBufferTranslateFunction(
           module, &serialized_flatbuffer, emit_builtin_tflite_ops,
|
@ -32,7 +32,7 @@ namespace tflite {
 
 // Translates the given MLIR `module` into a FlatBuffer and stores the
 // serialized flatbuffer into the string.
-bool MlirToFlatBufferTranslateFunction(mlir::Module module,
+bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
                                        std::string *serialized_flatbuffer,
                                        bool emit_builtin_tflite_ops,
                                        bool emit_select_tf_ops,
@ -50,6 +50,13 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
                    "TFLite string type">,
               BuildableType<"getType<mlir::TF::StringType>()">;
 
+//===----------------------------------------------------------------------===//
+// TFLite dialect uint8 type - uses the TF uint8 type as implementation
+//===----------------------------------------------------------------------===//
+def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
+                     "TFLite uint8 type">,
+                BuildableType<"getType<mlir::TF::Uint8Type>()">;
+
 //===----------------------------------------------------------------------===//
 // Activation function enum definitions.
 //===----------------------------------------------------------------------===//
@ -77,11 +84,17 @@ def TFL_AFAttr : StrEnumAttr<
 // These should match the padding enum in TFLite schema.
 def TFL_PAD_Same : StrEnumAttrCase<"SAME">;
 def TFL_PAD_Valid : StrEnumAttrCase<"VALID">;
+def TFL_MIRRORPAD_Reflect : StrEnumAttrCase<"REFLECT">;
+def TFL_MIRRORPAD_Symmetric : StrEnumAttrCase<"SYMMETRIC">;
 
 def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum", [
       TFL_PAD_Same, TFL_PAD_Valid
     ]>;
 
+def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [
+      TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // Min-max range pair definitions.
 //===----------------------------------------------------------------------===//
@ -432,13 +445,13 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
   }];
 
   let arguments = (
-    ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>>:$values,
+    ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>>:$values,
     I32Attr:$axis,
     TFL_AFAttr:$fused_activation_function
   );
 
   let results = (outs
-    TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>:$output
+    TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>:$output
   );
 
   let hasOptions = 1;
@ -553,9 +566,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [Broadcastable, NoSideEffect]> {
   }];
 
   let arguments = (
-    // TODO(haoliang): missing Uint8
-    ins TensorOf<[F32, I32, I64, I8]>:$lhs,
-    TensorOf<[F32, I32, I64, I8]>:$rhs);
+    ins TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$lhs,
+    TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$rhs);
 
   let results = (outs TFL_BoolTensor:$output);
 
@ -665,10 +677,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
   }];
 
   let arguments = (
-    // TODO: missing Uint8
     ins
-    TensorOf<[I1, F32, I32, I64, I8]>:$x,
-    TensorOf<[I1, F32, I32, I64, I8]>:$y
+    TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$x,
+    TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$y
   );
 
   let results = (outs TFL_BoolTensor:$output);
@ -1066,8 +1077,7 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
   }];
 
   let arguments = (ins
-    // TODO: missing uint8
-    TensorOf<[F32, I8, I32, I64]>:$input,
+    TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
     TensorOf<[I32, I64]>:$axis,
     BoolAttr:$keep_dims
   );
@ -1686,9 +1696,10 @@ def TFL_TanhOp: TFL_Op<"tanh", [
     Computes element-wise Hyperbolic tangent of input
   }];
 
-  let arguments = (ins AnyTensor:$x);
+  // TODO(haoliang): missing Uint8.
+  let arguments = (ins TensorOf<[F32, I16, I8]>:$x);
 
-  let results = (outs AnyTensor:$y);
+  let results = (outs TensorOf<[F32, I16, I8]>:$y);
 }
 
 def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
@ -1958,6 +1969,61 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
   let hasOptions = 1;
 }
 
+def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+  let summary = "Cast operator";
+
+  let description = [{
+    Casts input from input type to output type.
+  }];
+
+  // TODO(b/135538711): Add complex types here.
+  let arguments = (ins
+    TensorOf<[F32, I1, I32, I64]>:$input
+  );
+
+  let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
+
+  // TFLite's cast op does not utilize CastOptions, instead derives types
+  // from the TfLiteTensors.
+  let hasOptions = 0;
+}
+
+def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
+    NoSideEffect, TFL_OperandHasRank<1, 2>]> {
+  let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
+
+  let description = [{
+    This operation pads a input with mirrored values according to the paddings
+    you specify. paddings is an integer tensor with shape [n, 2],
+    where n is the rank of input.
+    For each dimension D of input, paddings[D, 0] indicates how many values
+    to add before the contents of input in that dimension,
+    and paddings[D, 1] indicates how many values to add after the contents of
+    input in that dimension.
+
+    Both paddings[D, 0] and paddings[D, 1] must be no greater than
+    input.dim_size(D) (or input.dim_size(D) - 1)
+    if copy_border is true (if false, respectively).
+
+    The padded size of each dimension D of the output is:
+
+      paddings(D, 0) + input.dim_size(D) + paddings(D, 1)
+  }];
+
+  let arguments = (ins
+    // TODO: add uint8 support when ready.
+    TensorOf<[F32, I32, I64]>:$input,
+    TensorOf<[I32, I64]>:$pad,
+    TFL_MirrorPaddingAttr:$mode
+  );
+
+  let results = (outs
+    TensorOf<[F32, I32, I64]>:$output
+  );
+
+  let hasOptions = 1;
+}
+
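To make the new `REFLECT`/`SYMMETRIC` modes concrete, here is a standalone one-dimensional illustration of the semantics described above; it is not part of the patch, and the bounds check mirrors the copy_border rule quoted in the op description:

```cpp
#include <cassert>
#include <vector>

enum class MirrorMode { kReflect, kSymmetric };

// 1-D mirror padding. REFLECT mirrors around the border value (excluding it);
// SYMMETRIC repeats the border value as the first mirrored element.
std::vector<float> MirrorPad1D(const std::vector<float>& in, int before,
                               int after, MirrorMode mode) {
  const int n = static_cast<int>(in.size());
  // REFLECT allows pads up to n - 1; SYMMETRIC allows pads up to n.
  const int offset = (mode == MirrorMode::kReflect) ? 1 : 0;
  assert(before <= n - offset && after <= n - offset);
  std::vector<float> out;
  out.reserve(before + n + after);
  for (int i = before; i > 0; --i) out.push_back(in[i - 1 + offset]);
  for (float v : in) out.push_back(v);
  for (int i = 0; i < after; ++i) out.push_back(in[n - 1 - i - offset]);
  return out;
}

// MirrorPad1D({1, 2, 3}, 2, 2, MirrorMode::kReflect)   -> {3, 2, 1, 2, 3, 2, 1}
// MirrorPad1D({1, 2, 3}, 2, 2, MirrorMode::kSymmetric) -> {2, 1, 1, 2, 3, 3, 2}
```

In both modes each output dimension has size paddings(D, 0) + input.dim_size(D) + paddings(D, 1), matching the formula in the op description.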
 //===----------------------------------------------------------------------===//
 // Quantization ops.
@ -75,7 +75,7 @@ class FixedResultUniformScale {
     Builder builder(op->getContext());
     IntegerType storage_type = builder.getIntegerType(BitWidth);
     const double scale = static_cast<double>(ScaleMantissa) *
-                         ::exp10(static_cast<double>(ScaleExp));
+                         ::pow(10.0, static_cast<double>(ScaleExp));
     return UniformQuantizedType::getChecked(
         Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
         StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
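The `::exp10` call is dropped because `exp10` is a GNU extension that not every toolchain provides, while `pow` is standard C++; for a mantissa/exponent pair the two compute the same scale. A tiny self-check (illustrative only, not part of the patch):

```cpp
#include <cassert>
#include <cmath>

int main() {
  // scale = mantissa * 10^exponent, as in FixedResultUniformScale above.
  const double scale = 3.0 * std::pow(10.0, -2.0);
  assert(std::fabs(scale - 0.03) < 1e-12);
  return 0;
}
```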
@ -31,6 +31,7 @@ limitations under the License.
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/Parser.h"  // TF:local_config_mlir
@ -98,7 +99,7 @@ int main(int argc, char** argv) {
   if (!module) return 1;
 
   // TODO(jpienaar): Expand to support inputs.
-  mlir::Function main = module->getNamedFunction("main");
+  mlir::FuncOp main = module->lookupSymbol<mlir::FuncOp>("main");
   QCHECK(main) << "No 'main' function specified.";
   if (main.getType().getNumInputs() != 0)
     LOG(QFATAL) << "NYI: Only nullary functions supported.";
@ -43,6 +43,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
       return DT_INT64;
     case toco::IODataType::STRING:
       return DT_STRING;
+    case toco::IODataType::BOOL:
+      return DT_BOOL;
     default:
       return DT_INVALID;
   }
 }
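With the new case, boolean model inputs map to `DT_BOOL` instead of falling through to `DT_INVALID`. An illustrative call (the enclosing namespace is assumed to be `tensorflow`, as the unqualified `DataType` return type suggests):

```cpp
#include <cassert>

// Assumed to run inside namespace tensorflow, next to the function above.
void CheckBoolMapping() {
  assert(ConvertIODataTypeToDataType(toco::IODataType::BOOL) == DT_BOOL);
}
```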
@ -829,3 +829,39 @@ func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor
 // CHECK-LABEL: slice1Tensor
 // CHECK: "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
 }
 
+func @mirror_pad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
+^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
+  %0 = "tf.MirrorPad"(%arg0, %arg1) { mode = "SYMMETRIC" }: (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
+  return %0#0 : tensor<? x f32>
+
+// CHECK-LABEL: mirror_pad
+// CHECK:  %0 = "tfl.mirror_pad"(%arg0, %arg1) {mode = "SYMMETRIC"} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+// CHECK:  return %0 : tensor<?xf32>
+}
+
+func @mirror_pad_reflect(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
+^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
+  %0 = "tf.MirrorPad"(%arg0, %arg1) { mode = "REFLECT" }: (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
+  return %0#0 : tensor<? x f32>
+
+// CHECK-LABEL: mirror_pad_reflect
+// CHECK:  %0 = "tfl.mirror_pad"(%arg0, %arg1) {mode = "REFLECT"} : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
+// CHECK:  return %0 : tensor<?xf32>
+}
+
+func @Tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  %2 = "tf.Tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  return %2: tensor<1xf32>
+
+// CHECK-LABEL: Tanh
+// CHECK:  %0 = "tfl.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+}
+
+func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
+  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
+  return %0 : tensor<1x2x2x5xf32>
+
+// CHECK-LABEL: cast
+// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
+}
@ -1,6 +1,7 @@
 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
 
-func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
+func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
+  attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} {
 ^bb0(%arg0: tensor<3x2xi32>):
 // CHECK: {
 // CHECK-NEXT: version: 3,
@ -14,7 +15,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
 // CHECK-NEXT: shape: [ 3, 2 ],
 // CHECK-NEXT: type: INT32,
 // CHECK-NEXT: buffer: 1,
-// CHECK-NEXT: name: "Input",
+// CHECK-NEXT: name: "input",
 // CHECK-NEXT: quantization: {
 // CHECK-EMPTY:
 // CHECK-NEXT: }
@ -38,7 +39,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
 // CHECK-NEXT: shape: [ ],
 // CHECK-NEXT: type: INT32,
 // CHECK-NEXT: buffer: 4,
-// CHECK-NEXT: name: "Const2",
+// CHECK-NEXT: name: "SameNameAsOutput1",
 // CHECK-NEXT: quantization: {
 // CHECK-EMPTY:
 // CHECK-NEXT: }
@ -46,7 +47,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
 // CHECK-NEXT: shape: [ ],
 // CHECK-NEXT: type: INT32,
 // CHECK-NEXT: buffer: 5,
-// CHECK-NEXT: name: "add",
+// CHECK-NEXT: name: "SameNameAsOutput",
 // CHECK-NEXT: quantization: {
 // CHECK-EMPTY:
 // CHECK-NEXT: }
@ -90,7 +91,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
   %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
   %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
   %2 = "tfl.sub" (%0, %1) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub")
-  %3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("Const2")
+  %3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("SameNameAsOutput")
   %4 = "tfl.add" (%3, %2) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add")
   return %4 : tensor<3x2xi32>
 }
@ -817,7 +817,7 @@ func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2
 // -----
 
 func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
-  // expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values}}
+  // expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type or TFLite uint8 type values}}
   %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
   return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
 }
@ -20,6 +20,7 @@ limitations under the License.
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "mlir/IR/Diagnostics.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir
@ -32,8 +33,9 @@ limitations under the License.
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
+using mlir::FuncOp;
 using mlir::MLIRContext;
-using mlir::Module;
+using mlir::ModuleOp;
 using stream_executor::port::StatusOr;
 using tensorflow::Status;
 
@ -47,7 +49,7 @@ static llvm::cl::opt<bool> print_function_result_mapping(
 enum TranslationStatus { kTrSuccess, kTrFailure };
 
 static int PrintFunctionResultMapping(const std::string &result,
-                                      Module module) {
+                                      ModuleOp module) {
   // Build model from the resultant string to extract the return values from
   // their source of truth.
   auto model =
@ -83,7 +85,7 @@ static int PrintFunctionResultMapping(const std::string &result,
   std::cout << '\'' << subgraph_name << "' outputs:\n";
   mlir::Operation *terminator = nullptr;
   if (subgraph->name()) {
-    if (auto fn = module.getNamedFunction(subgraph->name()->str()))
+    if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
       terminator = fn.back().getTerminator();
   }
   i = 0;
@ -34,7 +34,7 @@ limitations under the License.
 namespace tensorflow {
 
 using mlir::MLIRContext;
-using mlir::Module;
+using mlir::ModuleOp;
 using mlir::OwningModuleRef;
 using stream_executor::port::StatusOr;
 
@ -87,8 +87,8 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
       context);
 }
 
-bool ShouldRunQuantizePasses(mlir::Module m) {
-  if (mlir::Function main_fn = m.getNamedFunction("main")) {
+bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
+  if (mlir::FuncOp main_fn = m.lookupSymbol<mlir::FuncOp>("main")) {
     return main_fn.getAttrOfType<mlir::UnitAttr>("tf.quantize") !=
            mlir::Attribute();
   }
@ -100,6 +100,16 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
                                 bool lower_tensor_list_ops,
                                 mlir::PassManager *pass_manager) {
   pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
 
+  if (lower_tensor_list_ops) {
+    // Execute this pass before `CanonicalizerPass` in case some TensorList
+    // ops are constant folded into variant types.
+    // TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
+    // handle constant ops that produce `TensorList`.
+    // TODO(haoliang): Add this pass by default.
+    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
+  }
+
   // TODO(jpienaar): Revise post dialect constants.
   pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
   // Canonicalization includes const folding, which is utilized here to optimize
@ -112,10 +122,6 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
   if (emit_builtin_tflite_ops) {
     // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
     // the TFLite dialect.
-    // TODO(haoliang): Add this pass by default.
-    if (lower_tensor_list_ops) {
-      pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
-    }
     pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());
     pass_manager->addPass(mlir::createCanonicalizerPass());
     pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
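`AddTFToTFLConversionPasses` only populates the pass manager; the caller still runs the pipeline over a module. Roughly like the sketch below, which is not from the patch, uses only the pass constructors named above, and glosses over `PassManager` construction details that vary across MLIR revisions:

```cpp
#include "mlir/Pass/PassManager.h"  // TF:local_config_mlir

// Drive a hand-built subset of the pipeline above over `module`.
mlir::LogicalResult RunLowering(mlir::ModuleOp module) {
  mlir::PassManager pm;
  pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  pm.addPass(mlir::TF::CreateDecodeConstantPass());
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);  // Failure diagnostics go to the context's handlers.
}
```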
@ -132,7 +138,7 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
 }
 
 Status ConvertTFControlFlowToTFLOrFlatbuffer(
-    mlir::Module module, bool export_to_mlir, bool emit_builtin_tflite_ops,
+    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
     bool lower_tensor_list_ops, std::string *result) {
   mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
@ -46,7 +46,7 @@ LoadFromGraphdefOrMlirSource(
 // attribute "tf.quantize" by the importer module.
 // TODO(fengliuai): switch to the cmd flag once the flags are moved to this
 // file with main method.
-bool ShouldRunQuantizePasses(mlir::Module m);
+bool ShouldRunQuantizePasses(mlir::ModuleOp m);
 
 // Add the MLIR passes that convert TF control flow dialect to TF Lite dialect
 // to a MLIR `pass_manager`. These passes first raise the control flow in the TF
@ -69,7 +69,7 @@ void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize,
 // main function, Quantization is applied. If `export_to_mlir` is true, the
 // result is exported in MLIR text format, otherwise exported in flat buffer.
 Status ConvertTFControlFlowToTFLOrFlatbuffer(
-    mlir::Module module, bool export_to_mlir, bool emit_builtin_tflite_ops,
+    mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
     bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops,
     bool lower_tensor_list_ops, std::string* result);
 } // namespace tensorflow
@ -131,6 +131,7 @@ def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
 def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
 def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
 def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
+def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
 def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
 def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
 
@ -228,18 +229,25 @@ def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $a
 
 def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
 
+// TopK in TFL is always sorted so we ignore that attribute here.
+def : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>;
+
 def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
 
 def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
 
 def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
 
+def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
+
 def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
 
 def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
 
 def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>;
 
+def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>;
+
 def : Pat<
   (TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask),
   (TFL_StridedSliceOp $input, $begin, $end, $strides,
@ -72,7 +72,6 @@ DECL_CONVERT_OP(MatMul);
 DECL_CONVERT_OP(Pack);
 DECL_CONVERT_OP(Split);
 DECL_CONVERT_OP(SplitV);
-DECL_CONVERT_OP(TopKV2);
 DECL_CONVERT_OP(Unpack);
 
 #undef DECL_CONVERT_OP
@ -207,14 +206,6 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
   return matchSuccess();
 }
 
-PatternMatchResult ConvertTFTopKV2Op::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  // TopK in TFL is always sorted so we ignore that attribute here.
-  rewriter.replaceOpWithNewOp<TFL::TopKV2Op>(op, op->getOperand(0),
-                                             op->getOperand(1));
-  return matchSuccess();
-}
-
 PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_unpack_op = cast<TF::UnpackOp>(op);
@ -239,7 +230,7 @@ void LegalizeTF::runOnFunction() {
   populateWithGenerated(ctx, &patterns);
   RewriteListBuilder<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFGatherOp,
                      ConvertTFGatherV2Op, ConvertTFMatMulOp, ConvertTFPackOp,
-                     ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFTopKV2Op,
+                     ConvertTFSplitOp, ConvertTFSplitVOp,
                      ConvertTFUnpackOp>::build(patterns, ctx);
   applyPatternsGreedily(func, std::move(patterns));
 }
@ -61,7 +61,7 @@ namespace {
 
 class TensorListPatternRewriter : public PatternRewriter {
  public:
-  explicit TensorListPatternRewriter(Function fn)
+  explicit TensorListPatternRewriter(FuncOp fn)
       : PatternRewriter(fn.getBody()) {}
 
   Operation *createOperation(const OperationState &state) override {
@ -77,7 +77,7 @@ struct LowerStaticTensorListPass
   void runOnModule() override;
 
   // Apply type and op changes within a function.
-  LogicalResult RewriteFunction(Function func,
+  LogicalResult RewriteFunction(FuncOp func,
                                 TensorListPatternRewriter *rewriter);
 
   // Changes the function type of `cond_func` and `body_func`, and the result
@ -276,8 +276,8 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
 
   auto *context = &getContext();
   auto module = getModule();
-  Function cond_func = module.getNamedFunction(while_op->getCond());
-  Function body_func = module.getNamedFunction(while_op->getBody());
+  FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->getCond());
+  FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->getBody());
 
   if (cond_func) {
     // Change `cond_func`'s argument types to `unranked_argument_types`.
@ -327,7 +327,7 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
 }
 
 LogicalResult LowerStaticTensorListPass::RewriteFunction(
-    Function func, TensorListPatternRewriter *rewriter) {
+    FuncOp func, TensorListPatternRewriter *rewriter) {
   auto *context = &getContext();
 
   for (Block &block : func) {
@ -388,7 +388,7 @@ void LowerStaticTensorListPass::runOnModule() {
   // have a potential issue when one function taking a `DT_VARIANT` is processed
   // before the function that produces the `DT_VARIANT`. We need to carefully
   // order the functions to be processed.
-  std::vector<Function> funcs_in_module;
+  std::vector<FuncOp> funcs_in_module;
   for (auto func : getModule().getOps<FuncOp>()) {
     // Always place the main function to be the first in the list.
     if (func.getName() == "main") {
@ -48,7 +48,7 @@ class PostQuantizePass : public FunctionPass<PostQuantizePass> {
   bool emit_quant_adaptor_ops_;
 };
 
-void RemoveQuantizationAdaptorOps(Function func) {
+void RemoveQuantizationAdaptorOps(FuncOp func) {
   mlir::OpBuilder builder(func.getBody());
   auto& bb = func.getBlocks().front();
   auto* terminator = bb.getTerminator();
@ -50,52 +50,19 @@ struct QuantizePass : public FunctionPass<QuantizePass> {
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc"
 
-struct QuantizeConcatOp : public RewritePattern {
-  explicit QuantizeConcatOp(MLIRContext* context)
-      : RewritePattern(QuantizeOp::getOperationName(), 1, context) {}
-
-  PatternMatchResult matchAndRewrite(Operation* op,
-                                     PatternRewriter& rewriter) const override;
-};
-
-PatternMatchResult mlir::TFL::QuantizeConcatOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto quantize_op = cast<QuantizeOp>(op);
-  auto concat_op =
-      dyn_cast_or_null<ConcatenationOp>(quantize_op.input()->getDefiningOp());
-  if (!concat_op) {
-    return matchFailure();
-  }
-
-  SmallVector<Value*, 4> values;
-  values.reserve(concat_op.getNumOperands());
-  for (auto operand : concat_op.values()) {
-    if (auto opInst =
-            dyn_cast_or_null<DequantizeOp>(operand->getDefiningOp())) {
-      values.push_back(opInst.input());
-    } else {
-      return matchFailure();
-    }
-  }
-  rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
-      op, quantize_op.output()->getType(), values,
-      rewriter.getI32IntegerAttr(concat_op.axis().getZExtValue()),
-      rewriter.getStringAttr(concat_op.fused_activation_function()));
-  return matchSuccess();
-}
-
 void QuantizePass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto* ctx = func.getContext();
   TFL::populateWithGenerated(ctx, &patterns);
-  mlir::RewriteListBuilder<mlir::TFL::QuantizeConcatOp>::build(patterns, ctx);
+  mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
+      mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
   applyPatternsGreedily(func, std::move(patterns));
 }
 } // namespace
 
 // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass.
-FunctionPassBase *CreateQuantizePass() { return new QuantizePass(); }
+FunctionPassBase* CreateQuantizePass() { return new QuantizePass(); }
 
 static PassRegistration<QuantizePass> pass(
     "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect");
@ -22,10 +22,6 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
 // Quantize attribute $0 by using quantization parameter from %1.
 def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
 
-// Call the generic builder of `op`. Use the result type of $0 in the new op.
-class ReplaceWith<string op> : NativeCodeCall<"$_builder.create<" # op #
-  ">($0->getLoc(), $0->getResult(0)->getType(), $1, $2, $3)">;
-
 // Squash tfl.dequantize and tfl.quantize pairs.
 // TODO(fengliuai): Compare the scale of input and output. This can also be
 // squashed to a requantize op if the scales are different.
@ -39,98 +35,3 @@ def : Pat<(TFL_QuantizeOp
           (TFL_QConstOp
              $qtype,
              (QuantizeByQuantizedType $value, $qtype))>;
-
-// Quantize the AddOp if both inputs are dequantized and the output is
-// quantized.
-def : Pat<(TFL_QuantizeOp:$q
-            (TFL_AddOp (TFL_DequantizeOp $lhs), (TFL_DequantizeOp $rhs),
-             $fused_activation_function),
-            $output_type),
-          (ReplaceWith<"TFL::AddOp"> $q, $lhs, $rhs,
-            $fused_activation_function)>;
-
-// Quantize the Conv2DOp if the input and weight are dequantized. The scale of
-// the bias input is determined by the scales of input and weight operands.
-def : Pat<(TFL_QuantizeOp
-            (TFL_Conv2DOp
-              (TFL_DequantizeOp $in),
-              (TFL_DequantizeOp $weight),
-              (TFL_DequantizeOp $bias),
-              $dilation_h_factor,
-              $dilation_w_factor,
-              $fused_activation_function,
-              $padding,
-              $stride_h,
-              $stride_w),
-            $output_type),
-          (TFL_Conv2DOp
-            $in,
-            $weight,
-            $bias,
-            $dilation_h_factor,
-            $dilation_w_factor,
-            $fused_activation_function,
-            $padding,
-            $stride_h,
-            $stride_w)>;
-
-// Quantize the DepthwiseConv2DOp if the input and weight are dequantized. The
-// scale of the bias input is determined by the scales of input and weight
-// operands.
-def : Pat<(TFL_QuantizeOp
-            (TFL_DepthwiseConv2DOp
-              (TFL_DequantizeOp $in),
-              (TFL_DequantizeOp $weight),
-              (TFL_DequantizeOp $bias),
-              $dilation_h_factor,
-              $dilation_w_factor,
-              $fused_activation_function,
-              $padding,
-              $stride_h,
-              $stride_w,
-              $multiplier),
-            $output_type),
-          (TFL_DepthwiseConv2DOp
-            $in,
-            $weight,
-            $bias,
-            $dilation_h_factor,
-            $dilation_w_factor,
-            $fused_activation_function,
-            $padding,
-            $stride_h,
-            $stride_w,
-            $multiplier)>;
-
-// Quantize the ReshapeOp if the input is dequantized and output is quantized.
-// The pre-quantize pass can guarantee both quantization parameters are the
-// same.
-def : Pat<(TFL_QuantizeOp (TFL_ReshapeOp (TFL_DequantizeOp $in)), $output_type),
-          (TFL_ReshapeOp $in)>;
-
-// Quantize the ReshapeOp if the input is dequantized and output is quantized.
-// The pre-quantize pass has set the output quantization parameters to a
-// pre-defined value.
-def : Pat<(TFL_QuantizeOp (TFL_SoftmaxOp (TFL_DequantizeOp $in), $beta),
-            $output_type),
-          (TFL_SoftmaxOp $in, $beta)>;
-
-// Quantize the AveragePool2DOp if the input is dequantized and output is
-// quantized. The pre-quantize pass can guarantee both quantization parameters
-// are the same.
-def : Pat<(TFL_QuantizeOp (TFL_AveragePool2DOp (TFL_DequantizeOp $in),
-             $filter_height, $filter_width, $fused_activation_function,
-             $padding, $stride_h, $stride_w), $output_type),
-          (TFL_AveragePool2DOp $in,
-             $filter_height, $filter_width, $fused_activation_function,
-             $padding, $stride_h, $stride_w)>;
-
-// Quantize the MaxPool2DOp if the input is dequantized and output is
-// quantized. The pre-quantize pass can guarantee both quantization parameters
-// are the same.
-def : Pat<(TFL_QuantizeOp (TFL_MaxPool2DOp (TFL_DequantizeOp $in),
-             $padding, $stride_w, $tride_h, $stride_width, $stride_height,
-             $fused_activation_function), $output_type),
-          (TFL_MaxPool2DOp $in,
-             $padding, $stride_w, $tride_h, $stride_width, $stride_height,
-             $fused_activation_function)>;
@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/Operation.h"  // TF:local_config_mlir
@ -121,7 +122,7 @@ struct RequantizeState {
 //
 class QuantizationDriver {
  public:
-  explicit QuantizationDriver(Function fn) : builder_(fn.getBody()) {}
+  explicit QuantizationDriver(FuncOp fn) : builder_(fn.getBody()) {}
 
   // The entry point of the quantization parameters propagation.
   void Run();
@ -706,7 +707,7 @@ void QuantizationDriver::Run() {
   }
 }
 
-void ApplyQuantizationParamsPropagation(mlir::Function func) {
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func) {
   QuantizationDriver(func).Run();
 }
 
@ -20,12 +20,66 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_
 
 #include "mlir/Dialect/QuantOps/QuantTypes.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
+#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 
 namespace mlir {
 namespace TFL {
 
+// A generic rewrite pattern which matches any N-in-1-out operations with
+// quantization parameters propagated to all the operands and results values.
+// The quantization parameters are annotated by the Q/DQ op pairs. Each matched
+// pattern are rewritten by its quantized alternatives.
+//
+// This pattern assumes all the matched ops are quantizable. This assumption is
+// always right, except when a "Q" op is used as a requantize op. For non-"Q"
+// ops, quantization parameters should be propagated to their result.
+//
+// This pattern only matches ops which only have one result.
+template <typename Q, typename DQ>
+struct GenericFullQuantizationPattern : public RewritePattern {
+  explicit GenericFullQuantizationPattern(MLIRContext* context)
+      : RewritePattern(Q::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op,
+                                     PatternRewriter& rewriter) const override {
+    if (op->getNumResults() != 1) {
+      return matchFailure();
+    }
+    auto quantize_op = cast<Q>(op);
+    auto quantized_op = quantize_op.input()->getDefiningOp();
+    // If it is a block argument, requantize op, or has more than one result, we
+    // shouldn't rewrite this op.
+    if (!quantized_op || llvm::isa<Q>(quantized_op) ||
+        llvm::isa<DQ>(quantized_op) || quantized_op->getNumResults() != 1) {
+      return matchFailure();
+    }
+
+    // Collect all the quantized inputs and "clone" the matched op by these
+    // inputs.
+    SmallVector<Value*, 4> inputs;
+    inputs.reserve(quantized_op->getNumOperands());
+    for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) {
+      auto* operand = quantized_op->getOperand(i);
+      if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
+        inputs.push_back(op_inst.input());
+      } else {
+        return matchFailure();
+      }
+    }
+    // Use OpBuilder so we can use op name to create the new op.
+    OpBuilder builder(quantized_op);
+    OperationState new_state(
+        quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs,
+        op->getResult(0)->getType(), quantized_op->getAttrs());
+    Operation* new_op = builder.createOperation(new_state);
+    rewriter.replaceOp(op, {new_op->getResult(0)});
+    return matchSuccess();
+  }
+};
+
 // Converts the min/max/storage_type/narrow_range information to a
 // QuantizedType, and then returns the attribute containing the QuantizedType.
 TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, FloatAttr min,
@ -62,7 +116,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
 // quantization parameters are stored as adjacent quantize and dequantize ops
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function.
-void ApplyQuantizationParamsPropagation(mlir::Function func);
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func);
 
 } // end namespace TFL
 } // end namespace mlir
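A note on the design choice above: because the template keys on the quantize op `Q` and requires every operand of the defining op to arrive through a `DQ`, one pattern covers any single-result op whose quantization parameters have already been propagated. That is what lets this change delete both the per-op TableGen patterns in quantize_patterns.td and the concat-specific C++ `QuantizeConcatOp` pattern, replacing them with the single registration from the quantize.cc hunk earlier in this diff:

```cpp
// From the quantize.cc hunk above: one instantiation handles all quantizable
// single-result ops for TFL's quantize/dequantize pair ("patterns" and "ctx"
// are the pattern list and MLIRContext in scope at the registration site).
mlir::RewriteListBuilder<mlir::TFL::GenericFullQuantizationPattern<
    mlir::TFL::QuantizeOp, mlir::TFL::DequantizeOp>>::build(patterns, ctx);
```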
@ -103,6 +103,7 @@ cc_library(
         "transforms/generated_optimize.inc",
         "transforms/optimize.cc",
         "transforms/raise_control_flow.cc",
+        "translate/control_to_executor_dialect.cc",
     ],
     hdrs = [
         "ir/control_flow_ops.h",
@ -280,6 +281,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_proto_cc",
         "@local_config_mlir//:IR",
+        "@local_config_mlir//:StandardOps",
     ],
 )
 
@ -471,6 +473,7 @@ cc_library(
         "@llvm//:support",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Parser",
+        "@local_config_mlir//:Pass",
     ],
 )
 
@ -694,7 +694,7 @@ void Print(NextIterationSinkOp next_iteration, OpAsmPrinter *p) {
   *p << next_iteration.getOperationName() << " [";
   p->printOperand(next_iteration.getOperand(0));
   *p << "] ";
-  p->printOperand(next_iteration.getOperand(1));
+  p->printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
   *p << " : " << next_iteration.getOperand(1)->getType();
   p->printOptionalAttrDict(next_iteration.getAttrs());
 }
@ -250,25 +250,6 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
     TfeControlType: $control
   );
 
-  let builders = [OpBuilder<
-    "Builder *builder, OperationState *result, ArrayRef<Value *> operands = {}",
-    [{
-      assert(operands.size() >= 2 && "tf_executor.Switch builder expects at "
-             "least two operands");
-      return build(builder, result, operands[0], operands[1], operands.drop_front(2));
-    }]>,
-    OpBuilder<
-    "Builder *builder, OperationState *result, Value *data, Value *predicate, ArrayRef<Value *> controls = {}",
-    [{
-      Type dataTy = data->getType();
-      Type controlTy = ControlType::get(builder->getContext());
-      result->types = { dataTy, dataTy, controlTy };
-      result->operands.push_back(data);
-      result->operands.push_back(predicate);
-      result->operands.insert(result->operands.end(), controls.begin(), controls.end());
-    }]>
-  ];
-
   let verifier = ?;
 }
 
@ -311,7 +292,6 @@ def TfExecutor_SwitchNOp :
     Variadic<AnyType>:$outputs,
     TfeControlType: $control
   );
 
-
 }
 
 def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> {
@ -346,18 +326,6 @@ def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAf
     TensorOf<[I32]>:$valueIndex,
     TfeControlType:$control
   );
-
-  let builders = [OpBuilder<
-    "Builder *builder, OperationState *result, ArrayRef<Value *> operands",
-    [{
-      assert(operands.size() >= 1 && "tf_executor.Merge builder expects at "
-             "least one operand");
-      Type data_type = operands[0]->getType();
-      Type control_type = ControlType::get(builder->getContext());
-      result->types = { data_type, builder->getIntegerType(32), control_type};
-      result->operands.append(operands.begin(), operands.end());
-    }]>
-  ];
 }
 
 def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
@ -408,19 +376,6 @@ def TfExecutor_EnterOp : TfExecutor_Op<"Enter",
   );
 
   let verifier = ?;
-
-  let builders = [OpBuilder<
-    "Builder *builder, OperationState *result, ArrayRef<Value *> operands",
-    [{
-      assert(operands.size() >= 1 && "tf_executor.Enter builder "
-             "expects at least one operand");
-      result->operands.append(operands.begin(), operands.end());
-
-      Type control_type = ControlType::get(builder->getContext());
-      result->types.push_back(operands[0]->getType());
-      result->types.push_back(control_type);
-    }]>
-  ];
 }
 
 def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [NoSideEffect]> {
@ -472,12 +427,14 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [No
|
|||||||
);
|
);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *builder, OperationState *result, Type resultTy, ArrayRef<Value *> controlInputs = {}",
|
"Builder *builder, OperationState *result, Type result_type, "
|
||||||
|
"ArrayRef<Value *> control_inputs = {}, ArrayRef<NamedAttribute> attributes = {}",
|
||||||
[{
|
[{
|
||||||
Type tokenTy = TokenType::get(builder->getContext());
|
Type token_type = TokenType::get(builder->getContext());
|
||||||
Type controlTy = ControlType::get(builder->getContext());
|
Type control_type = ControlType::get(builder->getContext());
|
||||||
result->types = { resultTy, tokenTy, controlTy };
|
result->types = { result_type, token_type, control_type };
|
||||||
result->operands.append(controlInputs.begin(), controlInputs.end());
|
result->operands.append(control_inputs.begin(), control_inputs.end());
|
||||||
|
result->attributes.append(attributes.begin(), attributes.end());
|
||||||
}]>
|
}]>
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
@ -527,6 +484,19 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> {
|
|||||||
// Optional extra control inputs.
|
// Optional extra control inputs.
|
||||||
Variadic<TfeControlType>:$controlInputs
|
Variadic<TfeControlType>:$controlInputs
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *builder, OperationState *result, Value *token, "
|
||||||
|
"ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
|
||||||
|
[{
|
||||||
|
assert(operands.size() >= 1 && "tf_executor.NextIteration.Sink builder "
|
||||||
|
"expects at least one operand");
|
||||||
|
result->operands.push_back(token);
|
||||||
|
result->operands.insert(result->operands.end(), operands.begin(),
|
||||||
|
operands.end());
|
||||||
|
result->attributes.append(attributes.begin(), attributes.end());
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
|
def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
|
||||||
@ -590,6 +560,20 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", [NoSideEffect]
|
|||||||
);
|
);
|
||||||
|
|
||||||
let verifier = ?;
|
let verifier = ?;
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *builder, OperationState *result, "
|
||||||
|
"ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
|
||||||
|
[{
|
||||||
|
assert(operands.size() >= 1 && "tf_executor.ControlTrigger builder "
|
||||||
|
"expects at least one operand");
|
||||||
|
result->operands.insert(result->operands.end(), operands.begin(),
|
||||||
|
operands.end());
|
||||||
|
Type control_type = ControlType::get(builder->getContext());
|
||||||
|
result->types = {control_type};
|
||||||
|
result->attributes.append(attributes.begin(), attributes.end());
|
||||||
|
}]>
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> {
|
def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> {
|
||||||
|
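The ODS builder bodies above are spliced into the generated C++ `build()` overloads, so conversion code reaches them through `OpBuilder::create<>`. A minimal sketch of a call site that forwards the original node's attributes when materializing a `NextIteration` source/sink pair; the variable names (`old_op`, `data_and_controls`) are ours and not part of this patch:

```c++
// Hypothetical call site, assuming the generated build() overloads above
// and an OpBuilder positioned inside a tf_executor.graph body.
Type result_type = old_op->getResult(0)->getType();
auto source = builder.create<tf_executor::NextIterationSourceOp>(
    old_op->getLoc(), result_type, /*control_inputs=*/{}, old_op->getAttrs());
// Result #1 of the source op is the token that pairs it with its sink.
builder.create<tf_executor::NextIterationSinkOp>(
    old_op->getLoc(), /*token=*/source.getResult(1),
    /*operands=*/data_and_controls, old_op->getAttrs());
```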
@@ -79,6 +79,12 @@ def TF_AddNOp : TF_Op<"AddN", [Commutative, NoSideEffect]> {
   let summary = "Add all input tensors element wise.";
 
   let description = [{
+Inputs must be of same size and shape.
+
+```python
+x = [9, 7, 10]
+tf.math.add_n(x) ==> 26
+```
   }];
 
   let arguments = (ins
@@ -467,6 +473,15 @@ def TF_CosOp : TF_Op<"Cos", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes cos of x element-wise.";
 
   let description = [{
+Given an input tensor, this function computes cosine of every
+element in the tensor. Input range is `(-inf, inf)` and
+output range is `[-1,1]`. If input lies outside the boundary, `nan`
+is returned.
+
+```python
+x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10000, float("inf")])
+tf.math.cos(x) ==> [nan -0.91113025 0.87758255 0.5403023 0.36235774 0.48718765 -0.95215535 nan]
+```
   }];
 
   let arguments = (ins
@@ -1027,6 +1042,43 @@ Invert (flip) each bit of supported types; for example, type `uint8` value 01010
   let description = [{
 Flip each bit of supported types. For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101.
 This operation is performed on each element of the tensor argument `x`.
+
+Example:
+```python
+import tensorflow as tf
+from tensorflow.python.ops import bitwise_ops
+
+# flip 2 (00000010) to -3 (11111101)
+tf.assert_equal(-3, bitwise_ops.invert(2))
+
+dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
+              dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
+
+inputs = [0, 5, 3, 14]
+for dtype in dtype_list:
+  # Because of issues with negative numbers, let's test this indirectly.
+  # 1. invert(a) and a = 0
+  # 2. invert(a) or a = invert(0)
+  input_tensor = tf.constant([0, 5, 3, 14], dtype=dtype)
+  not_a_and_a, not_a_or_a, not_0 = [bitwise_ops.bitwise_and(
+                                      input_tensor, bitwise_ops.invert(input_tensor)),
+                                    bitwise_ops.bitwise_or(
+                                      input_tensor, bitwise_ops.invert(input_tensor)),
+                                    bitwise_ops.invert(
+                                      tf.constant(0, dtype=dtype))]
+
+  expected = tf.constant([0, 0, 0, 0], dtype=tf.float32)
+  tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected)
+
+  expected = tf.cast([not_0] * 4, tf.float32)
+  tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected)
+
+  # For unsigned dtypes let's also check the result directly.
+  if dtype.is_unsigned:
+    inverted = bitwise_ops.invert(input_tensor)
+    expected = tf.constant([dtype.max - x for x in inputs], dtype=tf.float32)
+    tf.assert_equal(tf.cast(inverted, tf.float32), tf.cast(expected, tf.float32))
+```
   }];
 
   let arguments = (ins
@@ -1348,6 +1400,52 @@ def TF_MinimumOp : TF_Op<"Minimum", [Broadcastable, NoSideEffect]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_MirrorPadOp : TF_Op<"MirrorPad", [NoSideEffect]> {
+  let summary = "Pads a tensor with mirrored values.";
+
+  let description = [{
+This operation pads a `input` with mirrored values according to the `paddings`
+you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
+the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+how many values to add before the contents of `input` in that dimension, and
+`paddings[D, 1]` indicates how many values to add after the contents of `input`
+in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater
+than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true
+(if false, respectively).
+
+The padded size of each dimension D of the output is:
+
+`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+
+For example:
+
+```
+# 't' is [[1, 2, 3], [4, 5, 6]].
+# 'paddings' is [[1, 1], [2, 2]].
+# 'mode' is SYMMETRIC.
+# rank of 't' is 2.
+pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
+                      [2, 1, 1, 2, 3, 3, 2]
+                      [5, 4, 4, 5, 6, 6, 5]
+                      [5, 4, 4, 5, 6, 6, 5]]
+```
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+    TF_I32OrI64Tensor:$paddings,
+
+    TF_AnyStrAttrOf<["REFLECT", "SYMMETRIC"]>:$mode
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_MulOp : TF_Op<"Mul", [Broadcastable, Commutative, NoSideEffect]>,
                WithBroadcastableBinOpBuilder {
   let summary = "Returns x * y element-wise.";
@@ -2194,9 +2292,17 @@ Specifically, `y = 1 / (1 + exp(-x))`.
 }
 
 def TF_SinOp : TF_Op<"Sin", [NoSideEffect, SameOperandsAndResultType]> {
-  let summary = "Computes sin of x element-wise.";
+  let summary = "Computes sine of x element-wise.";
 
   let description = [{
+Given an input tensor, this function computes sine of every
+element in the tensor. Input range is `(-inf, inf)` and
+output range is `[-1,1]`.
+
+```python
+x = tf.constant([-float("inf"), -9, -0.5, 1, 1.2, 200, 10, float("inf")])
+tf.math.sin(x) ==> [nan -0.4121185 -0.47942555 0.84147096 0.9320391 -0.87329733 -0.54402107 nan]
+```
   }];
 
   let arguments = (ins
@@ -2591,6 +2697,31 @@ retained with length 1.
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_TanhOp : TF_Op<"Tanh", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Computes hyperbolic tangent of `x` element-wise.";
+
+  let description = [{
+Given an input tensor, this function computes hyperbolic tangent of every
+element in the tensor. Input range is `[-inf, inf]` and
+output range is `[-1,1]`.
+
+```python
+x = tf.constant([-float("inf"), -5, -0.5, 1, 1.2, 2, 3, float("inf")])
+tf.math.tanh(x) ==> [-1. -0.99990916 -0.46211717 0.7615942 0.8336547 0.9640276 0.9950547 1.]
+```
+  }];
+
+  let arguments = (ins
+    TF_FpOrComplexTensor:$x
+  );
+
+  let results = (outs
+    TF_FpOrComplexTensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
   let summary = [{
 Creates a TensorList which, when stacked, has the value of `tensor`.
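As a side note on the MirrorPad description above, the output extent of each dimension is `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`. A small self-contained sketch of that arithmetic (a hypothetical helper, not part of this patch):

```c++
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Computes the padded output shape for MirrorPad given per-dimension
// (before, after) padding amounts; `paddings` holds one pair per input
// dimension, mirroring the [n, 2] paddings tensor of the op.
std::vector<int64_t> MirrorPadOutputShape(
    const std::vector<int64_t> &input_dims,
    const std::vector<std::pair<int64_t, int64_t>> &paddings) {
  assert(paddings.size() == input_dims.size() && "paddings must be [n, 2]");
  std::vector<int64_t> out_dims;
  out_dims.reserve(input_dims.size());
  for (size_t d = 0; d < input_dims.size(); ++d)
    out_dims.push_back(paddings[d].first + input_dims[d] + paddings[d].second);
  return out_dims;
}
// e.g. a 2x3 input with paddings {{1, 1}, {2, 2}} yields a 4x7 output,
// matching the SYMMETRIC example in the description.
```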
@@ -275,18 +275,18 @@ static LogicalResult Verify(FusedBatchNormOp op) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult IfOp::verify() {
-  auto thenAttr = getAttrOfType<FunctionAttr>("then_branch");
+  auto thenAttr = getAttrOfType<SymbolRefAttr>("then_branch");
   if (!thenAttr) return emitOpError("requires then_branch attribute");
 
-  auto elseAttr = getAttrOfType<FunctionAttr>("else_branch");
+  auto elseAttr = getAttrOfType<SymbolRefAttr>("else_branch");
   if (!elseAttr) return emitOpError("requires else_branch attribute");
 
-  auto module = getParentOfType<Module>();
-  auto thenFn = module.getNamedFunction(thenAttr.getValue());
+  auto module = getParentOfType<ModuleOp>();
+  auto thenFn = module.lookupSymbol<FuncOp>(thenAttr.getValue());
   if (!thenFn)
     return emitOpError("then_branch refers to an undefined function : ")
            << thenAttr;
-  auto elseFn = module.getNamedFunction(elseAttr.getValue());
+  auto elseFn = module.lookupSymbol<FuncOp>(elseAttr.getValue());
   if (!elseFn)
     return emitOpError("else_branch refers to an undefined function : ")
            << elseAttr;
@@ -627,9 +627,10 @@ OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult Verify(SoftmaxOp op) {
-  if (!IsOfRankOrUnranked(op.logits(), 2))
-    return op.emitOpError("requires operand to be 2D tensor");
+  if (!IsOfRankOrUnranked(op.logits(), 1) &&
+      !IsOfRankOrUnranked(op.logits(), 2)) {
+    return op.emitOpError("requires operand to be 1D/2D tensor");
+  }
   return success();
 }
 
@@ -727,20 +728,20 @@ void TruncateDivOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 LogicalResult WhileOp::verify() {
-  auto condAttr = getAttrOfType<FunctionAttr>("cond");
+  auto condAttr = getAttrOfType<SymbolRefAttr>("cond");
   if (!condAttr) return emitOpError("requires cond attribute");
 
-  auto module = getParentOfType<Module>();
-  auto condFn = module.getNamedFunction(condAttr.getValue());
+  auto module = getParentOfType<ModuleOp>();
+  auto condFn = module.lookupSymbol<FuncOp>(condAttr.getValue());
   auto condFuncType = condFn.getType();
 
   // Verify that the cond function has exactly one result.
   if (condFuncType.getNumResults() != 1)
     return emitOpError("requires cond function to have exactly one result");
 
-  auto bodyAttr = getAttrOfType<FunctionAttr>("body");
+  auto bodyAttr = getAttrOfType<SymbolRefAttr>("body");
   if (!bodyAttr) return emitOpError("requires body attribute");
-  auto bodyFn = module.getNamedFunction(bodyAttr.getValue());
+  auto bodyFn = module.lookupSymbol<FuncOp>(bodyAttr.getValue());
   auto bodyFuncType = bodyFn.getType();
 
   SmallVector<Type, 4> operands(getOperandTypes());
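The hunks above all migrate the same pattern: a function-valued attribute moves from the removed `FunctionAttr` to `SymbolRefAttr`, and the callee is resolved through the enclosing `ModuleOp`'s symbol table instead of `getNamedFunction`. A condensed sketch of that pattern (the helper name is ours, not part of the patch):

```c++
#include "mlir/IR/Function.h"  // FuncOp
#include "mlir/IR/Module.h"    // ModuleOp

using namespace mlir;

// Resolves a function referenced by a SymbolRefAttr on `op` through the
// enclosing module's symbol table; returns a null FuncOp if the attribute
// is missing or the symbol is undefined (callers emit the diagnostics).
static FuncOp LookupReferencedFunc(Operation *op, StringRef attr_name) {
  auto fn_attr = op->getAttrOfType<SymbolRefAttr>(attr_name);
  if (!fn_attr) return nullptr;
  auto module = op->getParentOfType<ModuleOp>();
  return module.lookupSymbol<FuncOp>(fn_attr.getValue());
}
```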
@@ -119,11 +119,11 @@ class IfOp : public Op<IfOp, TensorFlowOp, OpTrait::AtLeastNOperands<1>::Impl,
 
   // TODO(b/132271680): This is not following Google naming style
   StringRef getThen() {
-    return getAttrOfType<FunctionAttr>("then_branch").getValue();
+    return getAttrOfType<SymbolRefAttr>("then_branch").getValue();
   }
 
   StringRef getElse() {
-    return getAttrOfType<FunctionAttr>("else_branch").getValue();
+    return getAttrOfType<SymbolRefAttr>("else_branch").getValue();
   }
 
   LogicalResult verify();
@@ -157,8 +157,12 @@ class WhileOp : public Op<WhileOp, TensorFlowOp, OpTrait::VariadicOperands,
   using Op::Op;
   static StringRef getOperationName() { return "tf.While"; }
 
-  StringRef getCond() { return getAttrOfType<FunctionAttr>("cond").getValue(); }
-  StringRef getBody() { return getAttrOfType<FunctionAttr>("body").getValue(); }
+  StringRef getCond() {
+    return getAttrOfType<SymbolRefAttr>("cond").getValue();
+  }
+  StringRef getBody() {
+    return getAttrOfType<SymbolRefAttr>("body").getValue();
+  }
 
   LogicalResult verify();
 };
@@ -105,8 +105,10 @@ retained with length 1.
 
 // In MLIR, the 'tf.Placeholder.input' instruction is used to capture attributes
 // of function arguments.
+// Note: NoSideEffect trait is not added intentionally to preserve the captured
+// attributes even if the input is unused.
 def TF_PlaceholderInputOp : TF_Op<"Placeholder.input",
-                                  [NoSideEffect, SameOperandsAndResultType]> {
+                                  [SameOperandsAndResultType]> {
   let summary = "PlaceholderInput op";
 
   let description = [{
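The comment added above captures the design choice: dead-code elimination is only allowed to erase ops that are marked side-effect free and have no uses, so dropping `NoSideEffect` keeps an unused `tf.Placeholder.input` (and its captured attributes) alive. A rough illustration of the rule involved (our sketch, not code from this patch):

```c++
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"

// Canonicalization-style DCE may only erase ops that are both unused and
// known side-effect free; without the NoSideEffect trait the second check
// is false for tf.Placeholder.input, so the op survives.
void EraseTriviallyDeadOps(mlir::Block &block) {
  for (mlir::Operation &op : llvm::make_early_inc_range(block)) {
    if (op.use_empty() && op.hasNoSideEffect())
      op.erase();
  }
}
```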
@@ -19,6 +19,10 @@ limitations under the License.
 #ifdef HANDLE_TF_TYPE
 
 // class, enumerant, name
+HANDLE_TF_TYPE(Uint8, UINT8, "uint8")
+HANDLE_TF_TYPE(Uint16, UINT16, "uint16")
+HANDLE_TF_TYPE(Uint32, UINT32, "uint32")
+HANDLE_TF_TYPE(Uint64, UINT64, "uint64")
 HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
 HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
 HANDLE_TF_TYPE(Qint32, QINT32, "qint32")
@@ -0,0 +1,86 @@
+// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --dump-input=fail
+
+
+// CHECK-LABEL: func @islands_with_control
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: tensor<*xf32>)
+func @islands_with_control(tensor<*xf32>) -> tensor<*xf32> {
+^bb0(%0: tensor<*xf32>):
+  %1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
+  %2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// CHECK-NEXT: %[[GRAPH:[0-9]*]] = tf_executor.graph {
+// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT: %{{[0-9]*}} = "tf.Identity"(%[[ARG0]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
+// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[ARG0]], %[[ARG0]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: tf_executor.fetch %[[ADD]]#0 : tensor<*xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[GRAPH]] : tensor<*xf32>
+
+// CHECK-LABEL: func @LoopTest() {
+
+func @LoopTest() {
+  %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
+  %1:2 = "_tf.Enter"(%0#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
+  %2 = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
+  %3:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
+  %4:3 = "_tf.Merge"(%3#0, %1#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
+  %5:2 = "_tf.Const"(%4#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
+  %6:2 = "_tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
+  %7:2 = "_tf.LoopCond"(%6#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
+  %8:3 = "_tf.Switch"(%4#0, %7#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
+  %9:2 = "_tf.Exit"(%8#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
+  %10:2 = "_tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
+  %11:2 = "_tf.Const"(%10#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
+  %12:2 = "_tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
+  %13 = "_tf.ControlTrigger"(%2, %12#1, %9#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
+  %14 = "_tf.NextIteration.sink"(%12#0, %13) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> (!_tf.control)
+  return
+}
+
+// CHECK-NEXT: tf_executor.graph {
+// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[CONST]]#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
+// CHECK-NEXT: %[[NOOP:[0-9]*]] = tf_executor.island {
+// CHECK-NEXT: "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
+// CHECK-NEXT: tf_executor.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
+// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
+// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) {
+// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
+// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi1>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[LESS:[0-9]*]]#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
+// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
+// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
+// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
+// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<*xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = tf_executor.island(%[[IDENTITY]]#1) {
+// CHECK-NEXT: %{{[a-z0-9]*}} = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
+// CHECK-NEXT: tf_executor.yield %{{[a-z0-9]*}} : tensor<i32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = tf_executor.island {
+// CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+// CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
+// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
+// CHECK-NEXT: tf_executor.fetch
+// CHECK-NEXT: }
+// CHECK-NEXT: return
@@ -2176,14 +2176,12 @@ versions {
 # CHECK: %73:2 = "_tf.mul_2_Da30D05wlPU0"(%58#0, %72#0, %47#1) {_tpu_replicate = "cluster", device = "", name = "while/mul_2_Da30D05wlPU"} : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
 # CHECK: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
 # CHECK: func @less_than_5_If8q4vKg9jA0(%arg0: tensor<*xf32>) -> tensor<*xi1>
 # CHECK-NEXT: attributes {tf._noinline = true} {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Less/y", value = dense<5.000000e+00> : tensor<f32>} : () -> (tensor<f32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Less"(%arg0, %0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "Less"} : (tensor<*xf32>, tensor<f32>) -> (tensor<*xi1>, !_tf.control)
 # CHECK-NEXT: return %1#0 : tensor<*xi1>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
 # CHECK: func @mul_2_Da30D05wlPU0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
 # CHECK-NEXT: attributes {tf._noinline = true} {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "mul/y", value = dense<2.000000e+00> : tensor<1x1xf32>} : () -> (tensor<1x1xf32>, !_tf.control)
@@ -91,8 +91,7 @@ versions {
 # CHECK-NEXT: %0:2 = "_tf.PartitionedCall"() {Tin = [], Tout = ["tfdtype$DT_INT32"], config = "", config_proto = "", device = "", executor_type = "", f = @foo0, name = "PartitionedCall"} : () -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() -> tensor<i32>
+# CHECK: func @foo0() -> tensor<i32>
 # CHECK-NEXT: attributes {tf.experimental_ints_on_device = true} {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<5> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Identity"(%0#0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<i32>) -> (tensor<i32>, !_tf.control)
@@ -157,13 +157,11 @@ versions {
 # CHECK-NEXT: %1:2 = "_tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo0, @bar0], device = "", name = "Case", output_shapes = []} : (tensor<i32>) -> (tensor<*xf32>, !_tf.control)
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() -> tensor<10xf32> {
+# CHECK: func @foo0() -> tensor<10xf32> {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_1", value = dense<1.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control)
 # CHECK-NEXT: return %0#0 : tensor<10xf32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @bar0() -> tensor<10xf32> {
+# CHECK: func @bar0() -> tensor<10xf32> {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_2", value = dense<2.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control)
 # CHECK-NEXT: return %0#0 : tensor<10xf32>
 # CHECK-NEXT: }
@@ -526,16 +526,13 @@ versions {
 # CHECK-NEXT: %18:2 = "_tf.Identity"(%17#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_1_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @cond_false0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @cond_false0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
 # CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: return %1#0, %0#0 : tensor<*xi32>, tensor<*xi32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @cond_true0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @cond_true0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
 # CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: return %0#0, %1#0 : tensor<*xi32>, tensor<*xi32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
@@ -145,12 +145,10 @@ versions {
 #CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
 #CHECK-NEXT: return
 #CHECK-NEXT: }
-#CHECK-EMPTY:
-#CHECK-NEXT: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
+#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> {
 #CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "const", value = dense<[1, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>, !_tf.control)
 #CHECK-NEXT: return %0#0 : tensor<2xi32>
 #CHECK-NEXT: }
-#CHECK-EMPTY:
-#CHECK-NEXT: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+#CHECK: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> {
 #CHECK-NEXT: return %arg0 : tensor<*xi32>
 #CHECK-NEXT: }
@@ -303,16 +303,14 @@ versions {
 # CHECK-NEXT: %3:4 = "_tf.While"(%1#0, %2#0, %0#0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, body = @while_body_60, cond = @while_cond_50, device = "", name = "while", output_shapes = ["tfshape$", "tfshape$", "tfshape$"], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, !_tf.control)
 # CHECK-NEXT: return %3#2 : tensor<i32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @while_body_60(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) {
+# CHECK: func @while_body_60(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Add/y", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "add_1/y", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
 # CHECK-NEXT: %2:2 = "_tf.Add"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: %3:2 = "_tf.Add"(%arg0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "add_1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
 # CHECK-NEXT: return %3#0, %arg1, %2#0 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @while_cond_50(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> tensor<*xi1> {
+# CHECK: func @while_cond_50(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> tensor<*xi1> {
 # CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Less/y", value = dense<10> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Less"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
 # CHECK-NEXT: return %1#0 : tensor<*xi1>
@@ -279,14 +279,12 @@ versions {
 # CHECK-NEXT: %5:2 = "_tf.SymbolicGradient"(%0#0, %4#0) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], device = "", f = @foo0, f._disable_call_shape_inference = true, name = "gradients/foo_grad/SymbolicGradient"} : (tensor<f32>, tensor<*xf32>) -> (tensor<f32>, !_tf.control)
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo_grad0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
+# CHECK: func @foo_grad0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
 # CHECK-NEXT: attributes {tf._disable_call_shape_inference = true} {
 # CHECK-NEXT: %0:2 = "_tf.Mul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "", name = "mul_0"} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
 # CHECK-NEXT: return %0#0 : tensor<*xf32>
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0(%arg0: tensor<*xf32>) -> tensor<*xf32>
+# CHECK: func @foo0(%arg0: tensor<*xf32>) -> tensor<*xf32>
 # CHECK-NEXT: attributes {tf._disable_call_shape_inference = true, tf.gradient = @foo_grad0} {
 # CHECK-NEXT: %0:2 = "_tf.Exp"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Exp"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
 # CHECK-NEXT: %1:2 = "_tf.Neg"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Neg"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
@@ -41,12 +41,10 @@ versions {
 # CHECK-NEXT: %1 = "_tf.bar0"() {device = "", name = "unnamed1"} : () -> !_tf.control
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @foo0() {
+# CHECK: func @foo0() {
 # CHECK-NEXT: %0 = "_tf.bar0"() {device = "", name = "unnamed"} : () -> !_tf.control
 # CHECK-NEXT: return
 # CHECK-NEXT: }
-# CHECK-EMPTY:
-# CHECK-NEXT: func @bar0() {
+# CHECK: func @bar0() {
 # CHECK-NEXT: return
 # CHECK-NEXT: }
|
|||||||
|
# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s
|
||||||
|
|
||||||
|
node {
|
||||||
|
name: "PartitionedCall"
|
||||||
|
op: "PartitionedCall"
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tout"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_UINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "_gradient_op_type"
|
||||||
|
value {
|
||||||
|
s: "PartitionedCall-15"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "config"
|
||||||
|
value {
|
||||||
|
s: ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "config_proto"
|
||||||
|
value {
|
||||||
|
s: "\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "executor_type"
|
||||||
|
value {
|
||||||
|
s: ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "f"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "__inference_uint_const_14"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "__inference_uint_const_14"
|
||||||
|
output_arg {
|
||||||
|
name: "identity"
|
||||||
|
type: DT_UINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "Const"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_UINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_UINT8
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
int_val: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "Identity"
|
||||||
|
op: "Identity"
|
||||||
|
input: "Const:output:0"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_UINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "identity"
|
||||||
|
value: "Identity:output:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions {
|
||||||
|
producer: 29
|
||||||
|
min_consumer: 12
|
||||||
|
}
|
||||||
|
|
||||||
|
# CHECK: func @main
|
||||||
|
# CHECK: "_tf.PartitionedCall"()
|
||||||
|
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
|
||||||
|
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
|
||||||
|
# CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
|
||||||
|
# CHECK: return {{%[0-9]*#[0-9]*}} : tensor<!tf.uint8>
|
@@ -588,7 +588,7 @@ func @testSoftmax(tensor<8x16xf32>) -> tensor<8x16xf32> {
 // Test invalid tf.Softmax
 func @testSoftmax(tensor<8x8x8xf32>) -> tensor<8x8x8xf32> {
 ^bb0(%arg0: tensor<8x8x8xf32>):
-  // expected-error @+1 {{requires operand to be 2D tensor}}
+  // expected-error @+1 {{requires operand to be 1D/2D tensor}}
   %0 = "tf.Softmax"(%arg0) {T = "tfdtype$DT_FLOAT"} : (tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
   return %0 : tensor<8x8x8xf32>
 }
@@ -313,7 +313,7 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*
     %3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
     tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
     // CHECK: %3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
-    // CHECK: tf_executor.NextIteration.Sink [%3#1] %3#0 : tensor<*xf32>
+    // CHECK: tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
     tf_executor.fetch %3#0 : tensor<*xf32>
   }
   return %0 : tensor<*xf32>
@@ -70,7 +70,7 @@ static Value* LowerCondition(Location loc, Value* value, OpBuilder* builder) {
 // that is compatible for tensor cast.
 //
 static Operation* CallFn(Location loc,
-                         const std::function<Value*(int)>& get_arg, Function fn,
+                         const std::function<Value*(int)>& get_arg, FuncOp fn,
                          OpBuilder* builder) {
   FunctionType fn_type = fn.getType();
   llvm::SmallVector<Value*, 4> operands;
@@ -153,9 +153,9 @@ static LogicalResult LowerIfOp(IfOp op) {
   Value* cond_i1 = LowerCondition(loc, op.getCondition(), &builder);
   if (!cond_i1) return failure();
 
-  auto module = op_inst->getParentOfType<Module>();
-  auto then_fn = module.getNamedFunction(op.getThen());
-  auto else_fn = module.getNamedFunction(op.getElse());
+  auto module = op_inst->getParentOfType<ModuleOp>();
+  auto then_fn = module.lookupSymbol<FuncOp>(op.getThen());
+  auto else_fn = module.lookupSymbol<FuncOp>(op.getElse());
 
   // Split the basic block before the 'if'. The new dest will be our merge
   // point.
@@ -210,9 +210,9 @@ static LogicalResult LowerWhileOp(WhileOp op) {
 
   OpBuilder builder(op_inst);
 
-  auto module = op_inst->getParentOfType<Module>();
-  auto cond_fn = module.getNamedFunction(op.getCond());
-  auto body_fn = module.getNamedFunction(op.getBody());
+  auto module = op_inst->getParentOfType<ModuleOp>();
+  auto cond_fn = module.lookupSymbol<FuncOp>(op.getCond());
+  auto body_fn = module.lookupSymbol<FuncOp>(op.getBody());
 
   // Split the block containing the While op into two blocks. One containing
   // operations before the While op and other containing the rest. Create two
@@ -35,7 +35,6 @@ namespace TFControlFlow {
 FunctionPassBase *CreateRaiseTFControlFlowPass();
-
 } // namespace TFControlFlow
 
 } // namespace mlir
 
 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_
@@ -89,7 +89,7 @@ std::vector<GraphOptimizationPass*> GraphOptPass::FindPassIds() {
 }
 
 void GraphOptPass::runOnModule() {
-  mlir::Module module_in = getModule();
+  mlir::ModuleOp module_in = getModule();
   mlir::MLIRContext& ctx = getContext();
 
   // Convert MLIR to Graph
@ -0,0 +1,248 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This transformation pass transforms MLIR TF contol dialect into a combination
|
||||||
|
// of the TF and TF executor dialects.
|
||||||
|
//
|
||||||
|
// !! This code is only intended for migration purpose and will be deleted when
|
||||||
|
// !! the importer is updated to directly emit the tf_executor dialect.
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "tf-ctl-to-executor"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This pass checks if a function contains only operations in the TF control
|
||||||
|
// dialect and converts it to a mix of the tf_executor and tf dialects.
|
||||||
|
// Operations that exist in the tf_executor dialects are used directly,
|
||||||
|
// otherwise _tf operations are wrapped in an island and the _ prefix is
|
||||||
|
// removed. Control dependencies are moved to be handled by the island itself.
struct ControlToExecutorDialectConversion
    : public FunctionPass<ControlToExecutorDialectConversion> {
  void runOnFunction() override;

 private:
  tf_executor::IslandOp CreateIslandForOp(Operation *op, OpBuilder *builder);
};
} // end anonymous namespace

static bool IsUnderscoredTFOp(Operation *op) {
  return op->getName().getStringRef().startswith("_tf.");
}

static bool HasOnlyTFControlOperations(FuncOp function) {
  return llvm::all_of(function, [](Block &block) {
    return llvm::all_of(block, [](Operation &op) {
      return IsUnderscoredTFOp(&op) || isa<ReturnOp>(op);
    });
  });
}

tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp(
    Operation *op, OpBuilder *builder) {
  // Create a new region for the tf_executor.island body
  SmallVector<Value *, 8> operands;
  for (Value *operand : op->getOperands())
    if (operand->getType().isa<tf_executor::ControlType>())
      operands.push_back(operand);
  SmallVector<Type, 8> types;
  for (Type result_type : op->getResultTypes())
    if (!result_type.isa<TFControlFlow::TFControlType>())
      types.push_back(result_type);
  types.push_back(tf_executor::ControlType::get(&getContext()));

  auto island = builder->create<tf_executor::IslandOp>(
      op->getLoc(), types, operands, ArrayRef<NamedAttribute>{});
  island.body().push_back(new Block);

  return island;
}

void ControlToExecutorDialectConversion::runOnFunction() {
  if (!HasOnlyTFControlOperations(getFunction())) {
    LLVM_DEBUG(llvm::dbgs() << "Function has unsupported operation, skip "
                               "tf_executor dialect conversion\n");
    return;
  }
  if (getFunction().getBlocks().size() != 1) {
    LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip "
                               "tf_executor dialect conversion\n");
    return;
  }

  Block &body = getFunction().getBody().front();
  OpBuilder builder(&body, body.begin());

  // Create a new tf_executor.graph at the beginning of the function.
  auto graph_op = builder.create<tf_executor::GraphOp>(
      getFunction().getLoc(), getFunction().getType().getResults());
  graph_op.body().push_back(new Block);
  builder.setInsertionPointToEnd(&graph_op.GetBody());
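  // The map declared below ties each loop frame name to its NextIteration
  // source op, so that the matching NextIteration sink, visited later in this
  // walk, can retrieve the source's control token.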
  llvm::StringMap<tf_executor::NextIterationSourceOp> frame_name_to_loop;

  // Loop over operations in the function and move them into the graph region.
  for (Operation &op : llvm::make_early_inc_range(body)) {
    // Skip the just-created tf_executor.graph.
    if (isa<tf_executor::GraphOp>(op)) continue;

    // This is the new operation that will replace the current one in the
    // graph.
    Operation *replacement = nullptr;
    if (op.isKnownTerminator()) {
      // This is the return of the function, we will create a fetch in the
      // graph matching the operands of the return. The return is then updated
      // to take as operands the results of the tf_executor.graph operation.
      SmallVector<Value *, 8> ret_vals;
      for (Value *operand : op.getOperands()) ret_vals.push_back(operand);
      for (auto &graph_result : llvm::enumerate(graph_op.getResults()))
        op.setOperand(graph_result.index(), graph_result.value());
      builder.create<tf_executor::FetchOp>(getFunction().getLoc(), ret_vals);
      continue;
    }
    assert(IsUnderscoredTFOp(&op) && "Expected only _tf operations");

    // The operands and types arrays are used to create the tf_executor ops.
    SmallVector<Value *, 8> operands;
    operands.append(op.getOperands().begin(), op.getOperands().end());
    SmallVector<Type, 8> types;
    for (Type result_type : op.getResultTypes()) {
      if (result_type.isa<TFControlFlow::TFControlType>())
        types.push_back(tf_executor::ControlType::get(&getContext()));
      else
        types.push_back(result_type);
    }
    auto loc = op.getLoc();

    // Match the specific operation that has a tf_executor equivalent, the
    // others will be wrapped in an island.

    // FIXME: StringSwitch

    if (op.getName().getStringRef() == "_tf.Switch") {
      replacement = builder.create<tf_executor::SwitchOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.SwitchN") {
      replacement = builder.create<tf_executor::SwitchNOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.Merge") {
      replacement = builder.create<tf_executor::MergeOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.NextIteration.source") {
      replacement = builder.create<tf_executor::NextIterationSourceOp>(
          loc, op.getResult(0)->getType(), operands);
      // Record a mapping of the name to the nextiteration.source so that when
      // we convert the sink we can get the token.
      StringAttr frame = op.getAttrOfType<StringAttr>("name");
      assert(!frame.getValue().empty());
      frame_name_to_loop[frame.getValue()] =
          cast<tf_executor::NextIterationSourceOp>(replacement);
      // Replace the results here: since the _tf source does not produce a
      // token, there isn't a mapping for the new result #1.
      op.getResult(0)->replaceAllUsesWith(replacement->getResult(0));
      for (int i : llvm::seq<int>(1, op.getNumResults()))
        op.getResult(i)->replaceAllUsesWith(replacement->getResult(i + 1));
      replacement->setAttrs(op.getAttrList());
      op.erase();
      continue;
    } else if (op.getName().getStringRef() == "_tf.NextIteration.sink") {
      StringAttr frame = op.getAttrOfType<StringAttr>("name");
      assert(!frame.getValue().empty());
      tf_executor::NextIterationSourceOp srcOp =
          frame_name_to_loop[frame.getValue()];
      replacement = builder.create<tf_executor::NextIterationSinkOp>(
          loc, srcOp.token(), operands, ArrayRef<NamedAttribute>{});
      replacement->setAttrs(op.getAttrList());
      op.erase();
      continue;
    } else if (op.getName().getStringRef() == "_tf.LoopCond") {
      replacement = builder.create<tf_executor::LoopCondOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.Enter") {
      replacement = builder.create<tf_executor::EnterOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.Exit") {
      replacement = builder.create<tf_executor::ExitOp>(
          loc, types, operands, ArrayRef<NamedAttribute>{});
    } else if (op.getName().getStringRef() == "_tf.ControlTrigger") {
      replacement =
          builder.create<tf_executor::ControlTriggerOp>(loc, operands);
    } else {
      tf_executor::IslandOp island = CreateIslandForOp(&op, &builder);
      replacement = island.getOperation();

      // General case, drop the leading _ off the name and wrap in an island.
      OperationState result(loc, op.getName().getStringRef().drop_front());

      // Only the non-control operands are carried over, the island is handling
      // the control input.
      for (Value *operand : op.getOperands())
        if (!operand->getType().isa<tf_executor::ControlType>())
          result.operands.push_back(operand);

      // Add a result type for each non-control result we find
      bool sawControlResult = false;
      for (Type result_type : op.getResultTypes()) {
        if (result_type.isa<TFControlFlow::TFControlType>()) {
          sawControlResult = true;
          continue;
        }
        // We assume all control results are at the end of the result list.
        assert(!sawControlResult && "all control results must be last");
        result.types.push_back(result_type);
      }

      // Create the operation inside the island
      OpBuilder island_builder(&island.GetBody());
      Operation *inner_op = island_builder.createOperation(result);
      inner_op->setAttrs(op.getAttrList());

      // Add the terminator for the island
      SmallVector<Value *, 8> ret_vals(inner_op->getResults());
      island_builder.create<tf_executor::YieldOp>(loc, ret_vals);
    }

    // Copy the attributes from the original operation to the replacement and
    // remap the results.
    if (!isa<tf_executor::IslandOp>(replacement))
      replacement->setAttrs(op.getAttrList());
    for (int i : llvm::seq<int>(0, op.getNumResults()))
      op.getResult(i)->replaceAllUsesWith(replacement->getResult(i));
    op.erase();
  }
}

FunctionPassBase *CreateTFControlToExecutorDialectConversion() {
  return new ControlToExecutorDialectConversion();
}

} // namespace mlir

static mlir::PassRegistration<mlir::ControlToExecutorDialectConversion> pass(
    "tf-control-to-executor-conversion",
    "Transform from TF control dialect to TF executor dialect.");
@@ -88,23 +88,23 @@ class Exporter {
 // one entry function, which is identified by name "main". This entry function
 // is converted to the base of the graph. The rest of the functions are
 // converted to the library functions in that graph.
-static Status Convert(mlir::Module module, const ExporterConfigs& configs,
+static Status Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
                       std::unique_ptr<Graph>* graph,
                       FunctionLibraryDefinition* flib_def);

-// Converts a given Function to a FunctionDef and adds it to the function
+// Converts a given FuncOp to a FunctionDef and adds it to the function
 // definition library
 static Status ConvertLibFunction(const ExporterConfigs& configs,
                                  const Dialect* tf_dialect,
-                                 mlir::Function function,
+                                 mlir::FuncOp function,
                                  FunctionDefLibrary* flib);
-// Converts the given CFG Function to a Graph. The arguments and returns of
+// Converts the given FuncOp to a Graph. The arguments and returns of
 // function are added to the graph with special op names kArgOp and kRetOp.
 // Later on, this graph can be converted to a function definition and added to
 // another graph.
 static StatusOr<std::unique_ptr<Graph>> Convert(
     const ExporterConfigs& configs, const Dialect* tf_dialect,
-    mlir::Function function, FunctionDefLibrary* flib);
+    mlir::FuncOp function, FunctionDefLibrary* flib);

 private:
  explicit Exporter(Graph* graph, const Dialect* tf_dialect)
@@ -376,11 +376,11 @@ Status Exporter::AddNextIterationNode(mlir::Operation* inst) {

 StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
                                                    const Dialect* tf_dialect,
-                                                   mlir::Function function,
+                                                   mlir::FuncOp function,
                                                    FunctionDefLibrary* flib) {
   if (function.getBlocks().size() != 1) {
     return errors::FailedPrecondition(
-        "Input Function must have only one basic block!");
+        "Input FuncOp must have only one basic block!");
   }
   mlir::Block& block = function.front();
@@ -433,7 +433,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
     mlir::Type type = arg->getType();
     if (!type.isa<mlir::TensorType>()) {
       return errors::InvalidArgument(
-          "Functions arguments must have tensor types. Found ",
+          "FuncOps arguments must have tensor types. Found ",
           mlir::debugString(type), " in function ", function.getName().str());
     }
@@ -448,7 +448,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
   // definition library
   // TODO(prakalps): If two functions have cyclic dependence, this will
   // introduce an infinite loop.
-  auto func = function.getModule().getNamedFunction(op_name.ValueOrDie());
+  auto func =
+      function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+          op_name.ValueOrDie());
   if (func != nullptr) {
     TF_RETURN_IF_ERROR(ConvertLibFunction(confs, tf_dialect, func, flib));
     TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
@@ -483,7 +485,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,

 Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
                                     const Dialect* tf_dialect,
-                                    mlir::Function function,
+                                    mlir::FuncOp function,
                                     FunctionDefLibrary* flib) {
   // First look for the function in the current function library. If found,
   // nothing needs to be done.
@@ -510,8 +512,10 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
   // Checks for gradient attribute. If present converts the gradient function
   // and populates the GradientDef.
   auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
-  if (auto attr = function.getAttrOfType<mlir::FunctionAttr>(grad_string)) {
-    auto grad_func = function.getModule().getNamedFunction(attr.getValue());
+  if (auto attr = function.getAttrOfType<mlir::SymbolRefAttr>(grad_string)) {
+    auto grad_func =
+        function.getParentOfType<mlir::ModuleOp>().lookupSymbol<mlir::FuncOp>(
+            attr.getValue());
     TF_RETURN_IF_ERROR(
         ConvertLibFunction(configs, tf_dialect, grad_func, flib));
     GradientDef grad;
@@ -531,15 +535,15 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
   return Status::OK();
 }

-Status Exporter::Convert(mlir::Module module, const ExporterConfigs& configs,
+Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs,
                          std::unique_ptr<Graph>* graph,
                          FunctionLibraryDefinition* flib_def) {
   mlir::Identifier entry_func_id =
       mlir::Identifier::get("main", module.getContext());
-  absl::optional<mlir::Function> entry_func;
+  absl::optional<mlir::FuncOp> entry_func;
   FunctionDefLibrary flib;
   auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
-  for (auto function : module.getOps<mlir::Function>()) {
+  for (auto function : module.getOps<mlir::FuncOp>()) {
     if (function.isExternal())
       return errors::FailedPrecondition("External functions not supported");
@@ -567,14 +571,14 @@ Status Exporter::Convert(mlir::Module module, const ExporterConfigs& configs,
 }
 } // namespace

-Status ConvertMlirToGraph(mlir::Module module, const ExporterConfigs& confs,
+Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs,
                           std::unique_ptr<Graph>* graph,
                           FunctionLibraryDefinition* flib_def) {
   return Exporter::Convert(module, confs, graph, flib_def);
 }

 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-    mlir::Module module, const ExporterConfigs& confs) {
+    mlir::ModuleOp module, const ExporterConfigs& confs) {
   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                      FunctionDefLibrary());
   auto graph = absl::make_unique<Graph>(flib_def);
@@ -32,13 +32,13 @@ using stream_executor::port::StatusOr;

 // Given an MLIR module, returns a GraphDef.
 StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
-    mlir::Module module, const ExporterConfigs& configs);
+    mlir::ModuleOp module, const ExporterConfigs& configs);

 // Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
 // The "main" function of the module is stored in the graph and the rest of
 // functions are stored in the library.
 stream_executor::port::Status ConvertMlirToGraph(
-    mlir::Module module, const ExporterConfigs& confs,
+    mlir::ModuleOp module, const ExporterConfigs& confs,
     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
 } // namespace tensorflow
@@ -30,6 +30,7 @@ limitations under the License.
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
 #include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/Identifier.h"  // TF:local_config_mlir
 #include "mlir/IR/Location.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
@@ -86,7 +87,7 @@ class Importer {

   explicit Importer(
       const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
-      const NodeSpecs& specs, mlir::Module module,
+      const NodeSpecs& specs, mlir::ModuleOp module,
       std::unordered_map<std::string, std::string>* tf_name_to_mlir_name)
       : module_(module),
         context_(module.getContext()),
@@ -158,8 +159,8 @@ class Importer {
     return ::tensorflow::ConvertTensorProto(value, builder_.get());
   }

-  // Converts func name in graphdef to mlir::FunctionAttribute.
-  StatusOr<mlir::FunctionAttr> ConvertFunctionCallName(
+  // Converts func name in graphdef to mlir::SymbolRefAttribute.
+  StatusOr<mlir::SymbolRefAttr> ConvertFunctionCallName(
       const std::string& func_name);

   // Converts the given non-function-call AttrValue to an MLIR Attribute.
|
|||||||
using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;
|
using NodeValueMap = absl::flat_hash_map<int, mlir::Operation*>;
|
||||||
|
|
||||||
std::unique_ptr<mlir::OpBuilder> builder_;
|
std::unique_ptr<mlir::OpBuilder> builder_;
|
||||||
mlir::Module module_;
|
mlir::ModuleOp module_;
|
||||||
mlir::MLIRContext* context_;
|
mlir::MLIRContext* context_;
|
||||||
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
|
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
|
||||||
const FunctionLibraryDefinition& graph_flib_;
|
const FunctionLibraryDefinition& graph_flib_;
|
||||||
@@ -611,12 +612,12 @@ Status Importer::ConvertFunctionCallAttribute(
   return Status::OK();
 }

-StatusOr<mlir::FunctionAttr> Importer::ConvertFunctionCallName(
+StatusOr<mlir::SymbolRefAttr> Importer::ConvertFunctionCallName(
     const std::string& func_name) {
   TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
   auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
-  auto func = module_.getNamedFunction(mlir_func_name);
-  return builder_->getFunctionAttr(func);
+  auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
+  return builder_->getSymbolRefAttr(func);
 }

 StatusOr<mlir::Attribute> Importer::ConvertAttributeValue(
|
|||||||
if (!grad_func_name.empty()) {
|
if (!grad_func_name.empty()) {
|
||||||
TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
|
TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
|
||||||
auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
|
auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
|
||||||
auto grad_func = module_.getNamedFunction(mlir_grad_func_name);
|
auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
|
||||||
auto gradient_attr = builder_->getFunctionAttr(grad_func);
|
auto gradient_attr = builder_->getSymbolRefAttr(grad_func);
|
||||||
auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
|
auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
|
||||||
attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr));
|
attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr));
|
||||||
}
|
}
|
||||||
@ -1151,8 +1152,8 @@ Status Importer::Convert(llvm::StringRef func_name,
|
|||||||
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
|
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
|
||||||
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
|
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
|
||||||
// TODO(b/122040776): Uses debug info for FunctionDef.
|
// TODO(b/122040776): Uses debug info for FunctionDef.
|
||||||
auto function = mlir::Function::create(mlir::UnknownLoc::get(context_),
|
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
|
||||||
func_name, func_type, attrs);
|
func_name, func_type, attrs);
|
||||||
|
|
||||||
module_.push_back(function);
|
module_.push_back(function);
|
||||||
builder_ = absl::make_unique<mlir::OpBuilder>(function.getBody());
|
builder_ = absl::make_unique<mlir::OpBuilder>(function.getBody());
|
||||||
@ -1291,7 +1292,8 @@ StatusOr<mlir::OwningModuleRef> Importer::Convert(
|
|||||||
mlir::MLIRContext* context, const Graph& graph,
|
mlir::MLIRContext* context, const Graph& graph,
|
||||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||||
const NodeSpecs& specs) {
|
const NodeSpecs& specs) {
|
||||||
mlir::OwningModuleRef module = mlir::Module::create(context);
|
mlir::OwningModuleRef module =
|
||||||
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||||
Importer importer(flib_def, debug_info, specs, module.get(),
|
Importer importer(flib_def, debug_info, specs, module.get(),
|
||||||
&tf_name_to_mlir_name);
|
&tf_name_to_mlir_name);
|
||||||
|
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_

+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -19,6 +19,7 @@ limitations under the License.
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/Identifier.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
@@ -119,7 +120,7 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
         break;
       default:
         inst.emitWarning()
-            << "Skipping splat converstion for "
+            << "Skipping splat conversion for "
            << "an unsupported attribute type " << element_type;
        continue;
    }
@@ -63,7 +63,7 @@ static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate(
     "graphdef-to-splatted-mlir", GraphdefToSplattedMlirTranslateFunction);

 static LogicalResult MlirToGraphdefTranslateFunction(
-    Module module, llvm::StringRef output_filename) {
+    ModuleOp module, llvm::StringRef output_filename) {
   if (!module) return failure();

   std::error_code error;
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

 #include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Function.h"  // TF:local_config_mlir
 #include "mlir/IR/Location.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
@@ -22,8 +23,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

 namespace mlir {
-static mlir::Operation* ExtractOnlyOp(mlir::Module module) {
-  mlir::Function fn = module.getNamedFunction("main");
+static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
+  mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main");
   if (!fn) return nullptr;

   if (fn.getBlocks().size() != 1) return nullptr;
@@ -38,7 +39,8 @@ static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
   return &block.front();
 }

-static LogicalResult MlirToTfNodeDef(Module module, llvm::StringRef filename) {
+static LogicalResult MlirToTfNodeDef(ModuleOp module,
+                                     llvm::StringRef filename) {
   auto* context = module.getContext();

   auto file = openOutputFile(filename);
@@ -46,19 +46,15 @@ Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) {
       *type = builder.getIntegerType(1);
       return Status::OK();
     case DT_INT8:
-    case DT_UINT8:
       *type = builder.getIntegerType(8);
       return Status::OK();
     case DT_INT16:
-    case DT_UINT16:
       *type = builder.getIntegerType(16);
       return Status::OK();
     case DT_INT32:
-    case DT_UINT32:
       *type = builder.getIntegerType(32);
       return Status::OK();
     case DT_INT64:
-    case DT_UINT64:
       *type = builder.getIntegerType(64);
       return Status::OK();
     case DT_BFLOAT16:
@@ -115,7 +115,7 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
   return Status::OK();
 }

-Status ConvertAttribute(const mlir::FunctionAttr& attr, AttrValue* value) {
+Status ConvertAttribute(const mlir::SymbolRefAttr& attr, AttrValue* value) {
   value->mutable_func()->set_name(attr.getValue());
   return Status::OK();
 }
@@ -149,7 +149,7 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
       TensorProto tensor;
       TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
       *list->add_tensor() = tensor;
-    } else if (auto attr = a.dyn_cast<mlir::FunctionAttr>()) {
+    } else if (auto attr = a.dyn_cast<mlir::SymbolRefAttr>()) {
       AttrValue attrVal;
       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attrVal));
       *list->add_func() = attrVal.func();
@@ -218,8 +218,8 @@ Status ConvertAttributes(const llvm::ArrayRef<mlir::NamedAttribute> attrs,
     }
     AttrValue value;
     switch (attr.getKind()) {
-      case mlir::StandardAttributes::Function: {
-        auto func_attr = attr.cast<mlir::FunctionAttr>();
+      case mlir::StandardAttributes::SymbolRef: {
+        auto func_attr = attr.cast<mlir::SymbolRefAttr>();
         value.mutable_func()->set_name(func_attr.getValue());
         func_call_attrs[string(name)] = value;
         continue;
@@ -298,8 +298,8 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {

   let arguments = (ins
     Variadic<XLA_TensorOrTuple>:$val,
-    FunctionAttr:$cond,
-    FunctionAttr:$body
+    SymbolRefAttr:$cond,
+    SymbolRefAttr:$body
   );

   let results = (outs Variadic<XLA_TensorOrTuple>:$res);
@@ -320,7 +320,7 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> {

   let arguments = (ins
     Variadic<XLA_TensorOrTuple>:$operands_and_init,
-    FunctionAttr:$computation,
+    SymbolRefAttr:$computation,
     ElementsAttr:$dimensions
   );
@@ -116,12 +116,14 @@ class RandomOpsTest(xla_test.XLATestCase):
     def rng(dtype):
       return random_ops.truncated_normal(shape=[2], dtype=dtype)

-    self._testRngIsNotConstant(rng, dtypes.float32)
+    # TODO(b/34339814): make this test work with 16 bit float types.
+    for dtype in self._random_types() & {np.float32, np.float64}:
+      self._testRngIsNotConstant(rng, dtype)

   def testTruncatedNormalIsInRange(self):
     count = 10000000
     # TODO(b/34339814): make this test work with 16 bit float types.
-    for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
+    for dtype in self._random_types() & {np.float32, np.float64}:
       with self.session() as sess:
         with self.test_scope():
           x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -293,7 +293,7 @@ class TruncatedNormalOp : public XlaOpKernel {

 REGISTER_XLA_OP(Name("TruncatedNormal")
                     .CompileTimeConstantInput("shape")
-                    .TypeConstraint("dtype", DT_FLOAT),
+                    .TypeConstraint("dtype", {DT_FLOAT, DT_DOUBLE}),
                 TruncatedNormalOp);

 } // namespace
@@ -42,63 +42,199 @@ const char kRetvalOp[] = "_Retval";

 const int kMaxCallDepth = 100;

+Status AnalyzeResourceUsage(
+    const Graph* graph, const absl::optional<std::string>& function_name,
+    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
+    FunctionLibraryRuntime* lib_runtime,
+    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
+                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
+        source_to_path);
+
 bool IsControlFlowV1Node(const Node* n) {
   return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
           n->IsNextIteration());
 }

-// Given an output edge, find the corresponding input edge if given edge is
-// coming from a pass-through node. Otherwise, return nullptr.
-StatusOr<const Edge*> WalkBackPassThroughEdge(const Edge* e) {
-  const Node* n = e->src();
-
-  if (n->IsIdentity()) {
-    const Edge* ret;
-    TF_RETURN_IF_ERROR(n->input_edge(0, &ret));
-    return ret;
-  }
-
-  if (n->type_string() == kIdentityNOp) {
-    const Edge* ret;
-    TF_RETURN_IF_ERROR(n->input_edge(e->src_output(), &ret));
-    return ret;
-  }
-
-  // Reaching here means e is not coming from a pass through node, return empty
-  // vector to indicate we can no longer trace back.
-  return nullptr;
-}
-
 // TODO(ycao): Add this as Tensorflow Node method.
-StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node* n,
+StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node& n,
                                                                  int idx) {
   absl::InlinedVector<const Edge*, 1> res;
-  if (idx >= n->num_outputs()) {
+  if (idx >= n.num_outputs()) {
     return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
-                                   n->name(), " only has ", n->num_outputs(),
+                                   n.name(), " only has ", n.num_outputs(),
                                    " outputs.");
   }

-  for (const Edge* o : n->out_edges()) {
+  for (const Edge* o : n.out_edges()) {
     if (o->src_output() == idx) res.emplace_back(o);
   }
   return res;
 }

-bool IsStackOrTensorArraySource(const Node* n) {
-  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
+bool IsStackOrTensorArraySource(const Node& n) {
+  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());

   if (!op_info) return false;
   if (op_info->resource_kind() != XlaResourceKind::kStack &&
       op_info->resource_kind() != XlaResourceKind::kTensorArray)
     return false;
-  return n->num_outputs() > 0 && n->output_type(0) == DataType::DT_RESOURCE;
+  return n.num_outputs() > 0 && n.output_type(0) == DataType::DT_RESOURCE;
+}
+
+void PropagateFromStackOrTensorArraySourceOp(
+    const Node& n, const absl::optional<std::string>& function_name,
+    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
+        user_to_source) {
+  ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
+                                                n.type_string());
+  for (const Edge* o : n.out_edges()) {
+    if (o->IsControlEdge()) continue;
+    if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
+      continue;
+    }
+    (*user_to_source)[o] = src_node_info;
+  }
+}
+
+Status PropagateFromArgOp(
+    const Node& n, const absl::optional<std::string>& function_name,
+    const absl::flat_hash_set<int>& resource_arg_indices,
+    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
+        user_to_source) {
+  TF_RET_CHECK(n.type_string() == kArgOp);
+
+  int index;
+  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index));
+  if (!resource_arg_indices.contains(index)) return Status::OK();
+
+  TF_RET_CHECK(function_name.has_value())
+      << "ResourceUsageAnalysis does not support analyzing _Arg nodes "
+         "carrying Stack/TensorArray resource in given graph unless they "
+         "are in function calls.";
+
+  const ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
+                                                      n.type_string());
+
+  for (const Edge* o : n.out_edges()) {
+    if (o->IsControlEdge()) continue;
+    if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
+      continue;
+    }
+    (*user_to_source)[o] = src_node_info;
+  }
+
+  return Status::OK();
+}
+
+Status PropagateThroughCallOp(
+    const Node& n, const absl::optional<std::string>& function_name,
+    const int call_depth, FunctionLibraryRuntime* lib_runtime,
+    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
+        user_to_source,
+    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
+                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
+        source_to_path) {
+  if (call_depth > kMaxCallDepth) {
+    return errors::InvalidArgument(
+        "Function call stack in given graph is too deep, last function ",
+        "name is: ", function_name.value());
+  }
+  // resource_arg_indices_for_call contains all indices of the input
+  // arguments that carry Stack/TensorArray resource handles.
+  absl::flat_hash_set<int> resource_arg_indices_for_call;
+  for (const Edge* e : n.in_edges()) {
+    if (!user_to_source->contains(e)) continue;
+    resource_arg_indices_for_call.emplace(e->dst_input());
+  }
+
+  absl::string_view called_function_name = n.type_string();
+  FunctionLibraryRuntime::Handle handle;
+  TF_RETURN_IF_ERROR(InstantiateFunctionCall(n.def(), lib_runtime, &handle));
+  auto release_handle_on_return = gtl::MakeCleanup(
+      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
+  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
+
+  // Recursively analyze called function for resource sources and users.
+  absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
+                      absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
+      called_function_source_to_path;
+  TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
+      fbody->graph, absl::optional<std::string>(called_function_name),
+      call_depth + 1, resource_arg_indices_for_call, lib_runtime,
+      &called_function_source_to_path));
+
+  std::unordered_map<std::string, Node*> node_name_index =
+      fbody->graph->BuildNodeNameIndex();
+
+  for (auto it : called_function_source_to_path) {
+    ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
+
+    // If source is an _Arg, then the true source is actually corresponding
+    // edge that feeds into function call node with the same index.
+    if (src_node_info.op_ == kArgOp) {
+      const Node* arg_src = node_name_index[src_node_info.node_name_];
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
+
+      const Edge* e;
+      TF_RETURN_IF_ERROR(n.input_edge(index, &e));
+      const Node* true_src = e->src();
+      src_node_info.function_name_ = function_name;
+      src_node_info.node_name_ = true_src->name();
+      src_node_info.op_ = true_src->type_string();
+    }
+
+    for (const auto& dst_node_info : it.second) {
+      // If user is an _Retval, then the true user is actually corresponding
+      // edge of that _Retval.
+      if (dst_node_info.op_ == kRetvalOp) {
+        const Node* ret_user = node_name_index[dst_node_info.node_name_];
+        int index;
+        TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
+
+        absl::InlinedVector<const Edge*, 1> outs;
+        TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
+        for (const Edge* o : outs) (*user_to_source)[o] = src_node_info;
+      } else {
+        (*source_to_path)[src_node_info].emplace(dst_node_info);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+// Analyzes pass through values for Identity and IdentityN ops.
+Status PropagateThroughIdentityOp(
+    const Node& n,
+    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
+        user_to_source) {
+  TF_RET_CHECK(n.IsIdentity() || n.type_string() == kIdentityNOp);
+  if (n.IsIdentity()) {
+    for (const Edge* o : n.out_edges()) {
+      if (o->IsControlEdge()) continue;
+      const Edge* in;
+      TF_RETURN_IF_ERROR(n.input_edge(0, &in));
+      if (!user_to_source->contains(in)) continue;
+      user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
+    }
+  } else {
+    for (const Edge* o : n.out_edges()) {
+      if (o->IsControlEdge()) continue;
+      const Edge* in;
+      TF_RETURN_IF_ERROR(n.input_edge(o->src_output(), &in));
+      if (!user_to_source->contains(in)) continue;
+      user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
+    }
+  }
+
+  return Status::OK();
 }

 Status AnalyzeResourceUsage(
-    const Graph* graph, FunctionLibraryRuntime* lib_runtime,
-    const absl::optional<std::string>& function_name, const int call_depth,
-    const absl::flat_hash_set<int>& resource_arg_indices,
+    const Graph* graph, const absl::optional<std::string>& function_name,
+    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
+    FunctionLibraryRuntime* lib_runtime,
     absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
         source_to_path) {
@@ -127,120 +263,30 @@ Status AnalyzeResourceUsage(
     }

     // Record a resource source edge.
-    if (IsStackOrTensorArraySource(n)) {
-      ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n->name(),
-                                                    n->type_string());
-      for (const Edge* o : n->out_edges()) {
-        if (o->IsControlEdge()) continue;
-        if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
-          continue;
-        }
-        user_to_source[o] = src_node_info;
-      }
+    if (IsStackOrTensorArraySource(*n)) {
+      PropagateFromStackOrTensorArraySourceOp(*n, function_name,
+                                              &user_to_source);
       continue;
     }

     // Arguments that are listed in resource_arg_indices are also considered as
     // resource sources.
     if (n->IsArg()) {
-      int index;
-      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
-      if (!resource_arg_indices.contains(index)) continue;
-
-      TF_RET_CHECK(function_name.has_value())
-          << "ResourceUsageAnalysis does not support analyzing _Arg nodes "
-             "carrying Stack/TensorArray resource in given graph unless they "
-             "are in function calls.";
-
-      const ResourceUsageAnalysis::NodeInfo src_node_info(
-          function_name, n->name(), n->type_string());
-
-      for (const Edge* o : n->out_edges()) {
-        if (o->IsControlEdge()) continue;
-        if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
-          continue;
-        }
-        user_to_source[o] = src_node_info;
-      }
+      TF_RETURN_IF_ERROR(PropagateFromArgOp(
+          *n, function_name, resource_arg_indices, &user_to_source));
       continue;
     }

+    // Recursively analyze function call ops.
     if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
-      if (call_depth > kMaxCallDepth) {
-        return errors::InvalidArgument(
-            "Function call stack in given graph is too deep, last function ",
-            "name is: ", function_name.value());
-      }
-      // resource_arg_indices_for_call contains all indices of the input
-      // arguments that carry Stack/TensorArray resource handles.
-      absl::flat_hash_set<int> resource_arg_indices_for_call;
-      for (const Edge* e : n->in_edges()) {
-        if (!user_to_source.contains(e)) continue;
-        resource_arg_indices_for_call.emplace(e->dst_input());
-      }
-
-      absl::string_view called_function_name = n->type_string();
-      FunctionLibraryRuntime::Handle handle;
-      TF_RETURN_IF_ERROR(
-          InstantiateFunctionCall(n->def(), lib_runtime, &handle));
-      auto release_handle_on_return = gtl::MakeCleanup(
-          [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
-      const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
-
-      // Recursively analyze called function for resource sources and users.
-      absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
-                          absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
-          called_function_source_to_path;
-      TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
-          fbody->graph, lib_runtime,
-          absl::optional<std::string>(called_function_name), call_depth + 1,
-          resource_arg_indices_for_call, &called_function_source_to_path));
-
-      std::unordered_map<std::string, Node*> node_name_index =
-          fbody->graph->BuildNodeNameIndex();
-
-      for (auto it : called_function_source_to_path) {
-        ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
-
-        // If source is an _Arg, then the true source is actually corresponding
-        // edge that feeds into function call node with the same index.
-        if (src_node_info.op_ == kArgOp) {
-          const Node* arg_src = node_name_index[src_node_info.node_name_];
-          int index;
-          TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
-
-          const Edge* e;
-          TF_RETURN_IF_ERROR(n->input_edge(index, &e));
-          const Node* true_src = e->src();
-          src_node_info.function_name_ = function_name;
-          src_node_info.node_name_ = true_src->name();
-          src_node_info.op_ = true_src->type_string();
-        }
-
-        for (const auto& dst_node_info : it.second) {
-          // If user is an _Retval, then the true user is actually
-          // corresponding edge of that _Retval.
-          if (dst_node_info.op_ == kRetvalOp) {
-            const Node* ret_user = node_name_index[dst_node_info.node_name_];
-            int index;
-            TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
-
-            absl::InlinedVector<const Edge*, 1> outs;
-            TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
-            for (const Edge* o : outs) user_to_source[o] = src_node_info;
-          } else {
-            (*source_to_path)[src_node_info].emplace(dst_node_info);
-          }
-        }
-      }
+      TF_RETURN_IF_ERROR(PropagateThroughCallOp(*n, function_name, call_depth,
+                                                lib_runtime, &user_to_source,
+                                                source_to_path));
       continue;
     }

-    for (const Edge* o : n->out_edges()) {
-      if (o->IsControlEdge()) continue;
-      TF_ASSIGN_OR_RETURN(const Edge* e, WalkBackPassThroughEdge(o));
-      if (!e || !user_to_source.contains(e)) continue;
-      user_to_source.emplace(std::make_pair(o, user_to_source[e]));
+    if (n->IsIdentity() || n->type_string() == kIdentityNOp) {
+      TF_RETURN_IF_ERROR(PropagateThroughIdentityOp(*n, &user_to_source));
     }
   }
|
|||||||
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
|
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
|
||||||
source_to_path) {
|
source_to_path) {
|
||||||
return AnalyzeResourceUsage(
|
return AnalyzeResourceUsage(
|
||||||
graph, lib_runtime, /*function_name=*/{}, /*call_depth=*/0,
|
graph, /*function_name=*/{}, /*call_depth=*/0,
|
||||||
/*resource_arg_indices=*/absl::flat_hash_set<int>(), source_to_path);
|
/*resource_arg_indices=*/absl::flat_hash_set<int>(), lib_runtime,
|
||||||
|
source_to_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
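
A minimal sketch of a call site for the refactored analysis (the public
`ResourceUsageAnalysis::Analyze` wrapper name and the surrounding `graph` and
`lib_runtime` objects are assumptions; only the wrapper's body is visible in
the hunk above):

    absl::flat_hash_map<
        tensorflow::ResourceUsageAnalysis::NodeInfo,
        absl::flat_hash_set<tensorflow::ResourceUsageAnalysis::NodeInfo>>
        source_to_path;
    TF_RETURN_IF_ERROR(tensorflow::ResourceUsageAnalysis::Analyze(
        graph, lib_runtime, &source_to_path));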
@@ -1367,12 +1367,12 @@ For a more intuitive description, see the "Informal Description" section below.
 :                        :                     : detailed description.         :
 | `offset_dims`          | `ArraySlice<int64>` | The set of dimensions in the  |
 :                        :                     : output shape that offset into :
-:                        :                     : a array sliced from operand.  :
+:                        :                     : an array sliced from operand. :
 | `slice_sizes`          | `ArraySlice<int64>` | `slice_sizes[i]` is the       |
 :                        :                     : bounds for the slice on       :
 :                        :                     : dimension `i`.                :
 | `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
-:                        :                     : \: slice that are collapsed   :
+:                        :                     : slice that are collapsed      :
 :                        :                     : away. These dimensions must   :
 :                        :                     : have size 1.                  :
 | `start_index_map`      | `ArraySlice<int64>` | A map that describes how to   |
@@ -1383,8 +1383,11 @@ For a more intuitive description, see the "Informal Description" section below.
 For convenience, we label dimensions in the output array not in `offset_dims`
 as `batch_dims`.

-The output is an array of rank `batch_dims.size` + `operand.rank` -
-`collapsed_slice_dims`.size.
+The output is an array of rank `batch_dims.size` + `offset_dims.size`.
+
+The `operand.rank` must equal the sum of `offset_dims.size` and
+`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to
+`operand.rank`.

 If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
 `start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
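
As a worked instance of the rank rule in the hunk above (the dimension sizes
here are chosen purely for illustration, not taken from this change):

    operand.rank         = 3      (so slice_sizes.size must also be 3)
    collapsed_slice_dims = {0}    => collapsed_slice_dims.size = 1
    offset_dims.size     = 3 - 1  = 2
    start_indices shape  = [5, 1] with index_vector_dim = 1  => batch_dims.size = 1
    output.rank          = batch_dims.size + offset_dims.size = 1 + 2 = 3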
@@ -1405,61 +1408,65 @@ accounting for `collapsed_slice_dims` (i.e. we pick
 `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
 with the bounds at indices `collapsed_slice_dims` removed).

-Formally, the operand index `In` corresponding to an output index `Out` is
-computed as follows:
+Formally, the operand index `In` corresponding to a given output index `Out` is
+calculated as follows:

-1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
+1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
    vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
    Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
    this is well defined even if `G` is empty -- if `G` is empty then `S` =
    `start_indices`.

 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
    scattering `S` using `start_index_map`. More precisely:
-   1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
-      `start_index_map.size`.
-   2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
-3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
-   at the offset dimensions in `Out` according to the `collapsed_slice_dims`
-   set. More precisely:
-   1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
-      `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
-      (`expand_offset_dims` is defined below).
-   2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
-4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
-   addition.
-
-`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
-and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
+
+   1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
+      `start_index_map.size`.
+
+   2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+
+3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
+   at the offset dimensions in `Out` according to the `collapsed_slice_dims`
+   set. More precisely:
+
+   1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
+      `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
+      (`remapped_offset_dims` is defined below).
+
+   2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
+
+4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+   addition.
+
+`remapped_offset_dims` is a monotonic function with domain [`0`, `offset.size`)
+and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
 `offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
-`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
+`2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.

 ### Informal Description and Examples

 Informally, every index `Out` in the output array corresponds to an element `E`
 in the operand array, computed as follows:

 - We use the batch dimensions in `Out` to look up a starting index from
   `start_indices`.

-- We use `start_index_map` to map the starting index (which may have size less
+- We use `start_index_map` to map the starting index (whose size may be less
|
||||||
than operand.rank) to a "full" starting index into operand.
|
than operand.rank) to a "full" starting index into the `operand`.
|
||||||
|
|
||||||
- We dynamic-slice out a slice with size `slice_sizes` using the full starting
|
- We dynamic-slice out a slice with size `slice_sizes` using the full starting
|
||||||
index.
|
index.
|
||||||
|
|
||||||
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
|
- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
|
||||||
Since all collapsed slice dimensions have to have bound 1 this reshape is
|
Since all collapsed slice dimensions must have a bound of 1, this reshape is
|
||||||
always legal.
|
always legal.
|
||||||
|
|
||||||
- We use the offset dimensions in `Out` to index into this slice to get the
|
- We use the offset dimensions in `Out` to index into this slice to get the
|
||||||
input element, `E`, corresponding to output index `Out`.
|
input element, `E`, corresponding to output index `Out`.
|
||||||
|
|
||||||
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
|
`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
|
||||||
examples that follow. More interesting values for `index_vector_dim` does not
|
that follow. More interesting values for `index_vector_dim` do not change the
|
||||||
change the operation fundamentally, but makes the visual representation more
|
operation fundamentally, but make the visual representation more cumbersome.
|
||||||
cumbersome.
|
|
||||||
|
|
||||||
To get an intuition on how all of the above fits together, let's look at an
|
To get an intuition on how all of the above fits together, let's look at an
|
||||||
example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
|
example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
|
||||||
@ -1526,12 +1533,12 @@ As a final example, we use (2) and (3) to implement `tf.gather_nd`:
|
|||||||
|
|
||||||
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
|
`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
|
||||||
from the gather indices array as usual, except the starting index has only one
|
from the gather indices array as usual, except the starting index has only one
|
||||||
element, `X`. Similarly, there is only one output offset index with the value
|
element, `X`. Similarly, there is only one output offset index with the value
|
||||||
`O`<sub>`0`</sub>. However, before being used as indices into the input array,
|
`O`<sub>`0`</sub>. However, before being used as indices into the input array,
|
||||||
these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
|
these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
|
||||||
the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
|
the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
|
||||||
description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively, adding up
|
formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
|
||||||
to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
|
adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
|
||||||
[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
|
[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
|
||||||
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
|
[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
|
||||||
the semantics for `tf.gather_nd`.
|
the semantics for `tf.gather_nd`.
|
||||||
|
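As a concrete companion to the four formal steps in the gather description
above, here is a small self-contained C++ sketch (an editorial illustration;
the function name and the plain-vector types are assumptions, not XLA source)
that computes the operand index `In` for one output index `Out`, given the
starting index vector `S` already sliced out of `start_indices` (step 1):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration of steps 2-4 of the formal gather description.
std::vector<int64_t> GatherOperandIndex(
    const std::vector<int64_t>& S, const std::vector<int64_t>& start_index_map,
    const std::vector<int64_t>& offset_dims,
    const std::vector<int64_t>& remapped_offset_dims,
    const std::vector<int64_t>& Out, int64_t operand_rank) {
  // Step 2: scatter S through start_index_map into a full starting index.
  std::vector<int64_t> s_in(operand_rank, 0);
  for (size_t k = 0; k < start_index_map.size(); ++k) {
    s_in[start_index_map[k]] = S[k];
  }
  // Step 3: scatter the offset components of Out into O_in, skipping the
  // collapsed dimensions via remapped_offset_dims.
  std::vector<int64_t> o_in(operand_rank, 0);
  for (size_t k = 0; k < offset_dims.size(); ++k) {
    o_in[remapped_offset_dims[k]] = Out[offset_dims[k]];
  }
  // Step 4: element-wise addition gives the operand index.
  std::vector<int64_t> in(operand_rank);
  for (int64_t d = 0; d < operand_rank; ++d) {
    in[d] = s_in[d] + o_in[d];
  }
  return in;
}
```

With the document's own example mapping {0→1, 1→3, 2→4, 3→5}, the offset
component `Out[offset_dims[k]]` lands in operand dimension
`remapped_offset_dims[k]`, skipping the collapsed dimensions 0 and 2.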
@@ -104,11 +104,49 @@ cc_library(
     ],
 )

+cc_library(
+    name = "event_pool",
+    srcs = ["event_pool.cc"],
+    hdrs = ["event_pool.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:stream_executor",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+cc_library(
+    name = "semaphore",
+    srcs = ["semaphore.cc"],
+    hdrs = ["semaphore.h"],
+    deps = [
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
+tf_cc_test(
+    name = "semaphore_test",
+    srcs = ["semaphore_test.cc"],
+    deps = [
+        ":semaphore",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test_main",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
 cc_library(
     name = "shared_device_buffer",
     srcs = ["shared_device_buffer.cc"],
     hdrs = ["shared_device_buffer.h"],
     deps = [
+        ":event_pool",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/compiler/xla/service:transfer_manager",
@@ -132,6 +170,23 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "device",
+    srcs = ["device.cc"],
+    hdrs = ["device.h"],
+    deps = [
+        ":event_pool",
+        ":semaphore",
+        ":worker_thread",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/synchronization",
+    ],
+)
+
 cc_library(
     name = "local_client",
     srcs = [
@@ -149,9 +204,9 @@ cc_library(
     ],
     features = ["-use_header_modules"],
     deps = [
+        ":device",
         ":shared_device_buffer",
         ":types",
-        ":worker_thread",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:literal_util",
@@ -165,7 +220,6 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/service:computation_placer",
-        "//tensorflow/compiler/xla/service:custom_call_target_registry",
         "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xla/service:shaped_buffer",
         "//tensorflow/core:bfc_allocator",
@@ -199,10 +253,11 @@ tf_pybind_extension(
         ":local_client",
         ":shared_device_buffer",
         ":types",
-        ":worker_thread",
         ":xrt",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
@@ -225,6 +280,7 @@ tf_pybind_extension(
         "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
         "//tensorflow/compiler/xla/client/lib:svd",
         "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:custom_call_target_registry",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/service:hlo_graph_dumper",
tensorflow/compiler/xla/python/device.cc (new file)
@@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/device.h"

#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

Device::Device(se::StreamExecutor* executor, bool synchronous_deallocation,
               bool asynchronous, bool allow_event_reuse)
    : synchronous_deallocation_(synchronous_deallocation),
      event_pool_(allow_event_reuse),
      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
  compute_stream_ = absl::make_unique<se::Stream>(executor);
  host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
  device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
  callback_stream_ = absl::make_unique<se::Stream>(executor);
  compute_stream_->Init();
  host_to_device_stream_->Init();
  device_to_host_stream_->Init();
  callback_stream_->Init();
  device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
  for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
    auto stream = absl::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_device_streams_.push_back(std::move(stream));
  }
  execute_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_execute");
  callback_thread_ = absl::make_unique<WorkerThread>(
      tensorflow::Env::Default(), "py_xla_callback");
}

Device::~Device() {
  Status status = SynchronizeAllActivity();
  if (!status.ok()) {
    LOG(ERROR) << "Error when closing device: " << status;
  }
}

Status Device::SynchronizeAllActivity() {
  Status status;
  // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
  // suffice. However on the Host platform SynchronizeAllActivity is a dummy
  // implementation that doesn't actually block. To make sure activity has
  // stopped, also block on the compute stream. If SynchronizeAllActivity is
  // fixed, we could remove the BlockHostUntilDone call.
  status.Update(compute_stream_->BlockHostUntilDone());
  bool ok = compute_stream_->parent()->SynchronizeAllActivity();
  if (!ok) {
    status.Update(Unknown("SynchronizeAllActivity failed."));
  }
  return status;
}

Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
                                        se::Stream* dst_stream,
                                        se::DeviceMemoryBase src_buffer,
                                        se::DeviceMemoryBase dst_buffer) {
  // The default implementation simply calls ThenMemcpyD2D, and assumes that
  // the buffer addresses identify the devices. This does not work
  // on all platforms; this method is virtual so it can be overridden.
  src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
  return Status::OK();
}

void Device::ThenExecuteOnCallbackThread(
    se::Stream* stream, std::function<void()> callback) const {
  stream->ThenDoHostCallback([this, callback]() mutable {
    callback_thread_->Schedule(std::move(callback));
  });
}

se::Stream* Device::GetDeviceToDeviceStream() {
  absl::MutexLock lock(&mu_);
  int i = next_device_to_device_stream_;
  next_device_to_device_stream_ =
      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_streams_.at(i).get();
}

}  // namespace xla
tensorflow/compiler/xla/python/device.h (new file)
@@ -0,0 +1,134 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_

#include <memory>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/python/worker_thread.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace xla {

// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers.
class Device {
 public:
  // If synchronous_deallocation is true, the host must not free buffers until
  // compute/transfers that use those buffers have completed. For example, this
  // typically is the case for the "cpu" platform, where compute/transfers are
  // operations that take place on another thread.
  //
  // If asynchronous is false, the host will synchronize to the device after
  // each execution or transfer. This is intended for debugging only.
  Device(se::StreamExecutor* executor, bool synchronous_deallocation,
         bool asynchronous, bool allow_event_reuse);
  virtual ~Device();

  bool synchronous_deallocation() const { return synchronous_deallocation_; }

  EventPool& event_pool() { return event_pool_; }

  se::Stream* compute_stream() const { return compute_stream_.get(); }
  se::Stream* host_to_device_stream() const {
    return host_to_device_stream_.get();
  }
  se::Stream* device_to_host_stream() const {
    return device_to_host_stream_.get();
  }

  // Returns a device to device stream. Allocates streams in a round-robin
  // fashion amongst the available streams.
  se::Stream* GetDeviceToDeviceStream();

  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
  virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
                                          se::Stream* dst_stream,
                                          se::DeviceMemoryBase src_buffer,
                                          se::DeviceMemoryBase dst_buffer);

  WorkerThread* execute_thread() const { return execute_thread_.get(); }

  // Enqueues a host callback on 'stream', to be executed by callback_thread_.
  // ThenDoHostCallback is often constrained in what it can do, in particular,
  // on GPU the callback runs on a thread belonging to the GPU runtime and
  // cannot perform GPU operations itself.
  void ThenExecuteOnCallbackThread(se::Stream* stream,
                                   std::function<void()> callback) const;

  // Helper for releasing values at the tail of a stream on a worker thread.
  // Copies `object`, and destroys the copy when the tail of the stream is
  // reached. The destruction happens either in the caller's thread or on the
  // worker thread (depending on thread schedules), not a device callback, so
  // it is safe if the destructor frees device resources (e.g., GPU objects).
  // TODO(phawkins): use move-capture when we can use C++14 features.
  template <typename T>
  void ThenRelease(se::Stream* stream, T object) const {
    if (callback_stream_.get() != stream) {
      callback_stream_->ThenWaitFor(stream);
    }
    ThenExecuteOnCallbackThread(callback_stream_.get(),
                                [object]() { /* releases object */ });
  }

  Semaphore& compute_semaphore() { return compute_semaphore_; }

 private:
  Status SynchronizeAllActivity();

  bool synchronous_deallocation_;

  EventPool event_pool_;

  // Semaphore used to limit how many programs can be enqueued on the compute
  // stream by the host ahead of the device.
  Semaphore compute_semaphore_;

  std::unique_ptr<se::Stream> compute_stream_;
  std::unique_ptr<se::Stream> host_to_device_stream_;
  std::unique_ptr<se::Stream> device_to_host_stream_;
  std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;

  // Number of device-to-device streams to create in the multistream case.
  static constexpr int kNumDeviceToDeviceStreams = 4;

  absl::Mutex mu_;
  int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;

  // Callback stream is used for running short host-side callbacks after device
  // side events, without preventing the device-side stream from doing useful
  // work.
  std::unique_ptr<se::Stream> callback_stream_;

  // A worker thread, used for replicated computation launches.
  std::unique_ptr<WorkerThread> execute_thread_;

  // A worker thread, used for callbacks. It is necessary that this be a
  // different thread to the execute thread because we acquire the compute
  // semaphore during calls to Execute but release it from a callback and if
  // they are the same thread we might deadlock.
  std::unique_ptr<WorkerThread> callback_thread_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_
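A brief usage illustration of `ThenRelease` (an editorial sketch, not code
from this change; the helper name is hypothetical and the `device`, `stream`,
and destination buffer are assumed to come from the surrounding runtime):

```cpp
#include <memory>
#include <vector>

// Hypothetical helper: copy host data to a device buffer on `stream`, and use
// Device::ThenRelease to keep the staging vector alive until the stream has
// consumed it.
void EnqueueH2DAndRelease(xla::Device* device, se::Stream* stream,
                          se::DeviceMemoryBase* dst,
                          std::shared_ptr<std::vector<float>> staging) {
  // Enqueue the copy; the staging memory must outlive it.
  stream->ThenMemcpy(dst, staging->data(), staging->size() * sizeof(float));
  // ThenRelease copies the shared_ptr and destroys the copy on the callback
  // worker thread once `stream` reaches this point, so the staging buffer is
  // freed without blocking the device-side stream.
  device->ThenRelease(stream, std::move(staging));
}
```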
tensorflow/compiler/xla/python/event_pool.cc (new file)
@@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/event_pool.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

EventPool::Handle::~Handle() {
  if (pool_ && event_) {
    absl::MutexLock lock(&pool_->mu_);
    pool_->free_events_.push(std::move(event_));
  }
}

EventPool::EventPool(bool allow_reuse) : allow_reuse_(allow_reuse) {}

StatusOr<EventPool::Handle> EventPool::ThenAllocateAndRecordEvent(
    se::Stream* stream) {
  Handle event;

  if (allow_reuse_) {
    event.pool_ = this;
    absl::MutexLock lock(&mu_);
    if (!free_events_.empty()) {
      event.event_ = std::move(free_events_.top());
      free_events_.pop();
    }
  }
  if (!event.event_) {
    event.event_ = absl::make_unique<se::Event>(stream->parent());
    TF_RET_CHECK(event.event_->Init()) << "Event initialization failed";
  }
  stream->ThenRecordEvent(event.event_.get());
  return event;
}

}  // namespace xla
tensorflow/compiler/xla/python/event_pool.h (new file)
@@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_

#include <memory>
#include <stack>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace xla {

class EventPool {
 public:
  class Handle {
   public:
    Handle() = default;
    ~Handle();

    Handle(const Handle&) = delete;
    Handle(Handle&&) = default;
    Handle& operator=(const Handle&) = delete;
    Handle& operator=(Handle&&) = default;

    se::Event* event() const { return event_.get(); }

   private:
    friend class EventPool;

    EventPool* pool_ = nullptr;
    std::unique_ptr<se::Event> event_;
  };

  // Initializes a new EventPool. If `allow_reuse` is true, then events will be
  // returned to the pool when their handles are deleted and made available to
  // subsequent allocations. Reuse only works on the GPU platform.
  explicit EventPool(bool allow_reuse);

  // Allocates a new (or reused) event from the pool, and records the event on
  // `stream`.
  //
  // Reuse is only possible on GPU. Event allocation and recording are coupled
  // in a single operation because on GPU it is recording an event that makes it
  // a "new" event. According to the CUDA documentation it is safe to call
  // cudaEventRecord even if that event may still be in use on the device; APIs
  // such as cudaStreamWaitEvent capture the state of the event at the time of
  // the host-side call and are not affected by a later host-side
  // cudaEventRecord.
  StatusOr<Handle> ThenAllocateAndRecordEvent(se::Stream* stream);

 private:
  const bool allow_reuse_;

  absl::Mutex mu_;
  std::stack<std::unique_ptr<se::Event>> free_events_ GUARDED_BY(mu_);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_
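A short usage sketch of how allocate-and-record composes with a cross-stream
wait (an editorial illustration only; the function name is hypothetical and
the two streams are assumed to come from the surrounding runtime):

```cpp
#include <utility>

#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Hypothetical sketch: record an event at the tail of `src` and make `dst`
// wait for it.
xla::Status RecordAndWait(xla::EventPool* pool, se::Stream* src,
                          se::Stream* dst) {
  TF_ASSIGN_OR_RETURN(xla::EventPool::Handle handle,
                      pool->ThenAllocateAndRecordEvent(src));
  // Work enqueued on `dst` after this point waits until `src` has reached
  // the recorded event.
  dst->ThenWaitFor(handle.event());
  // When `handle` goes out of scope, the event is returned to the pool for
  // reuse (when reuse is enabled).
  return xla::Status::OK();
}
```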
@@ -18,22 +18,39 @@ limitations under the License.
 // Asynchronous execution:
 // -----------------------
 //
-// If 'asynchronous' is set when constructing the client, computations and
-// host-to-device transfers do not block the host waiting for the operation to
-// complete but instead return control to the host immediately. This allows
-// Python logic to overlap with device-side computation.
+// Computations and host-to-device transfers do not need to block the host
+// waiting for the operation to complete but instead return control to the host
+// immediately. This allows Python logic to overlap with device-side
+// computation.
 //
 // For a good user experience, we must be careful only to enqueue operations
 // that are unlikely to fail; as a rule error checking must be done eagerly
 // before returning control to the client.
 //
+// The degree to which the client can enqueue operations ahead of the device
+// is limited by a semaphore. There are two modes: asynchronous, where we
+// allow the client to enqueue up to 32 executions ahead of the device, and
+// synchronous, where we limit the client to having one enqueued operation at
+// a time. The value of 32 is arbitrary.
+//
+// Even in asynchronous mode, it is important that we do not permit
+// unbounded queue-ahead. Firstly it is problematic when the user does something
+// like the following in Python:
+// %timeit run_computation()
+// To the timeit logic, op() appears to be extremely cheap since it is deferring
+// all of its real work and not blocking, and so the %timeit will run op() many
+// (e.g., 10000) times to get better timing resolution, even though in reality
+// it may be expensive. Secondly, on CPU the allocator is synchronized with the
+// head of the compute stream, and we allocate buffers for all of the enqueued
+// programs without any reuse (unlike GPU). This means that the memory usage
+// is proportional to the queue size.
+//
 // Multi-stream execution:
 // -----------------------
 //
-// On certain platforms (e.g., TPU), we use a multistream execution design,
-// where different Streams are used for host-to-device transfers,
-// device-to-host transfers, and compute. This allows us to overlap transfers on
-// and off the device with computation.
+// We use a multistream execution design, where different Streams are used for
+// host-to-device transfers, device-to-host transfers, and compute. This allows
+// us to overlap transfers on and off the device with computation.
 //
 // Synchronization between streams occurs via BufferDefinitionEvents that
 // describe when the contents of a logical buffer are known to be valid on
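The queue-ahead discipline described in the comment above is easiest to see in
isolation. Below is a minimal editorial sketch of a counting semaphore with the
same acquire-before-enqueue / release-at-stream-tail pattern; the class name
and interface are illustrative assumptions, not the `Semaphore` class added in
semaphore.h (whose definition does not appear in this diff):

```cpp
#include "absl/synchronization/mutex.h"

// Editorial sketch only: bounds how far the host may run ahead of the device.
class QueueAheadSemaphore {
 public:
  // capacity would be 32 in asynchronous mode, 1 in synchronous mode.
  explicit QueueAheadSemaphore(int capacity) : available_(capacity) {}

  // Host side: called before enqueueing an execution; blocks once
  // `capacity` executions are already in flight.
  void Acquire() {
    absl::MutexLock lock(&mu_);
    mu_.Await(absl::Condition(
        +[](int* available) { return *available > 0; }, &available_));
    --available_;
  }

  // Called from a host callback at the tail of the compute stream when the
  // device finishes an execution.
  void Release() {
    absl::MutexLock lock(&mu_);
    ++available_;
  }

 private:
  absl::Mutex mu_;
  int available_;  // guarded by mu_
};
```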
@@ -66,9 +83,7 @@ limitations under the License.

 #include "absl/memory/memory.h"
 #include "absl/strings/str_format.h"
-#include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"
-#include "absl/synchronization/notification.h"
 #include "absl/time/time.h"
 #include "include/pybind11/pybind11.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
@@ -78,12 +93,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
 #include "tensorflow/compiler/xla/python/types.h"
-#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -93,98 +108,6 @@ namespace xla {

 namespace py = pybind11;

-// Registers a 'fn_capsule' as a CPU custom call target.
-// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
-// "xla._CPU_CUSTOM_CALL_TARGET".
-Status RegisterCpuCustomCallTarget(const std::string& fn_name,
-                                   py::capsule capsule) {
-  static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
-  if (absl::string_view(capsule.name()) != kName) {
-    return InvalidArgument(
-        "Argument to RegisterCpuCustomCallTargetRegistry was not a "
-        "xla._CPU_CUSTOM_CALL_TARGET capsule.");
-  }
-  CustomCallTargetRegistry::Global()->Register(
-      fn_name, static_cast<void*>(capsule), "Host");
-  return Status::OK();
-}
-
-Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
-               bool synchronous_deallocation, bool asynchronous)
-    : use_multiple_streams_(use_multiple_streams),
-      synchronous_deallocation_(synchronous_deallocation),
-      asynchronous_(asynchronous) {
-  compute_stream_ = std::make_shared<se::Stream>(executor);
-  compute_stream_->Init();
-  if (use_multiple_streams) {
-    host_to_device_stream_ = std::make_shared<se::Stream>(executor);
-    device_to_host_stream_ = std::make_shared<se::Stream>(executor);
-    callback_stream_ = std::make_shared<se::Stream>(executor);
-    host_to_device_stream_->Init();
-    device_to_host_stream_->Init();
-    callback_stream_->Init();
-    device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
-    for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
-      auto stream = std::make_shared<se::Stream>(executor);
-      stream->Init();
-      device_to_device_streams_.push_back(std::move(stream));
-    }
-  } else {
-    callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
-        compute_stream_;
-    device_to_device_streams_.push_back(compute_stream_);
-  }
-  worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
-                                                   "py_xla_execute");
-}
-
-Device::~Device() {
-  Status status = SynchronizeAllActivity();
-  if (!status.ok()) {
-    LOG(ERROR) << "Error when closing device: " << status;
-  }
-}
-
-Status Device::SynchronizeAllActivity() {
-  Status status;
-  // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
-  // suffice. However on the Host platform SynchronizeAllActivity is a dummy
-  // implementation that doesn't actually block. To make sure activity has
-  // stopped, also block on the compute stream. If SynchronizeAllActivity is
-  // fixed, we could remove the BlockHostUntilDone call.
-  status.Update(compute_stream_->BlockHostUntilDone());
-  bool ok = compute_stream_->parent()->SynchronizeAllActivity();
-  if (!ok) {
-    status.Update(Unknown("SynchronizeAllActivity failed."));
-  }
-  return status;
-}
-
-Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream,
-                                        se::Stream* dst_stream,
-                                        se::DeviceMemoryBase src_buffer,
-                                        se::DeviceMemoryBase dst_buffer) {
-  // The default implementation simply calls ThenMemcpyD2D, and assumes that
-  // the buffer addresses identify the devices. This does not work
-  // on all platforms; this method is virtual so it can be overridden.
-  src_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
-  return Status::OK();
-}
-
-void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
-                                       std::function<void()> callback) const {
-  stream->ThenDoHostCallback(
-      [this, callback]() { worker_thread_->Schedule(std::move(callback)); });
-}
-
-se::Stream* Device::GetDeviceToDeviceStream() {
-  absl::MutexLock lock(&mu_);
-  int i = next_device_to_device_stream_;
-  next_device_to_device_stream_ =
-      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
-  return device_to_device_streams_.at(i).get();
-}
-
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform, LocalClient* client, double memory_fraction,
     bool preallocate) {
@@ -237,44 +160,57 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
   options.set_platform(platform);
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));

+  bool gpu_platform = platform_name == "gpu";
   std::unique_ptr<se::DeviceMemoryAllocator> allocator;
-  if (allocator_config.kind == AllocatorConfig::Kind::kBFC ||
-      (platform_name == "gpu" &&
-       allocator_config.kind == AllocatorConfig::Kind::kDefault)) {
-    if (platform_name != "gpu") {
-      return Unimplemented("BFCAllocator only available for GPU.");
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator;
+  if (gpu_platform) {
+    if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) {
+      TF_ASSIGN_OR_RETURN(
+          allocator,
+          CreateBFCAllocator(platform, client, allocator_config.memory_fraction,
+                             allocator_config.preallocate));
     }
-    TF_ASSIGN_OR_RETURN(
-        auto bfc_allocator,
-        CreateBFCAllocator(platform, client, allocator_config.memory_fraction,
-                           allocator_config.preallocate));
-    allocator = std::move(bfc_allocator);
+    tensorflow::SubAllocator* sub_allocator = new tensorflow::GpuHostAllocator(
+        client->backend().stream_executor(0).ValueOrDie(), /*numa_node=*/0,
+        /*alloc_visitors=*/{},
+        /*free_visitors=*/{});
+    // TODO(phawkins): allow the user to tune this.
+    const int64 kGpuHostMemoryLimitBytes = 64 * (1LL << 30);
+    host_memory_allocator = absl::make_unique<tensorflow::BFCAllocator>(
+        sub_allocator, kGpuHostMemoryLimitBytes, /*allow_growth=*/true,
+        /*name=*/"xla_gpu_host_bfc");
+
+  } else if (allocator_config.kind == AllocatorConfig::Kind::kBFC) {
+    return Unimplemented("BFCAllocator only available for GPU.");
   }

   std::vector<std::unique_ptr<Device>> devices;
   devices.reserve(client->device_count());
-  bool use_multiple_streams = (platform_name != "cpu");
-  bool synchronous_deallocation = !use_multiple_streams;
+  bool synchronous_deallocation = platform_name == "cpu";
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutor* executor =
         client->backend().stream_executor(i).ValueOrDie();
-    devices.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
-                                                synchronous_deallocation,
-                                                asynchronous));
+    devices.push_back(absl::make_unique<Device>(
+        executor, synchronous_deallocation, asynchronous,
+        /*allow_event_reuse=*/gpu_platform));
   }
-  return std::make_shared<PyLocalClient>(platform_name, client,
-                                         std::move(devices),
-                                         std::move(allocator), asynchronous);
+  return std::make_shared<PyLocalClient>(
+      platform_name, client, std::move(devices), std::move(allocator),
+      std::move(host_memory_allocator));
 }

 PyLocalClient::PyLocalClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<Device>> devices,
-    std::unique_ptr<se::DeviceMemoryAllocator> allocator, bool asynchronous)
+    std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+    std::unique_ptr<tensorflow::Allocator> host_memory_allocator)
     : platform_name_(std::move(platform_name)),
       client_(client),
       devices_(std::move(devices)),
       owned_allocator_(std::move(allocator)),
+      host_memory_allocator_(std::move(host_memory_allocator)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
   if (owned_allocator_ != nullptr) {
@@ -303,56 +239,6 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
   return LiteralToPython(std::make_shared<Literal>(std::move(literal)));
 }

-static StatusOr<std::unique_ptr<PyLocalBuffer>> TransferHostToDeviceAsync(
-    const PythonBufferTree& tree, int device_ordinal,
-    std::shared_ptr<PyLocalClient> client, const Device& device) {
-  se::DeviceMemoryAllocator* allocator = client->allocator();
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
-                      transfer_manager->AllocateScopedShapedBuffer(
-                          shape, allocator, device_ordinal));
-  TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
-      device.host_to_device_stream(), buffer));
-
-  auto it = tree.leaves.begin();
-  for (const ShapeUtil::IndexedShape& indexed_shape :
-       ShapeUtil::GetLeafShapes(shape)) {
-    TF_RET_CHECK(it != tree.leaves.end());
-    ShapedBuffer leaf(
-        indexed_shape.shape,
-        transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
-        client->client()->platform(), device_ordinal);
-    leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
-    if (device.use_multiple_streams() &&
-        !transfer_manager->CanShapedBufferBeAccessedNow(
-            device.host_to_device_stream()->parent(), leaf)) {
-      device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
-    }
-    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
-        device.host_to_device_stream(), *it, leaf));
-    ++it;
-  }
-  std::shared_ptr<BufferDefinitionEvent> definition_event;
-  if (device.use_multiple_streams()) {
-    TF_ASSIGN_OR_RETURN(definition_event,
-                        BufferDefinitionEvent::Create(
-                            device.host_to_device_stream()->parent()));
-    definition_event->RecordOnStream(device.host_to_device_stream());
-  }
-  std::shared_ptr<SharedDeviceBuffer> device_buffer =
-      SharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
-                                                 definition_event);
-  if (device.synchronous_deallocation()) {
-    device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
-                                     device_buffer);
-  }
-  return absl::make_unique<PyLocalBuffer>(shape, std::move(device_buffer),
-                                          std::move(client));
-}
-
 /* static */
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
     const py::object& argument, std::shared_ptr<PyLocalClient> client,
@@ -366,82 +252,83 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
   // remain live until the transfer is complete.
   auto py_buffer_ref =
       client->py_ref_manager().ManageReferences(absl::MakeSpan(tree.arrays));
+  tree.arrays.clear();

   // We are done manipulating Python objects; release the GIL.
   py::gil_scoped_release gil_release;
   VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString()
           << " device ordinal: " << device_ordinal;

-  const Device& device = client->device(device_ordinal);
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> buffer,
-                      TransferHostToDeviceAsync(tree, device_ordinal,
-                                                std::move(client), device));
-
-  device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref));
-  return buffer;
-}
-
-/*static */ StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
-PyLocalBuffer::FromPythonValues(
-    const std::vector<std::pair<py::object, int>>& arguments,
-    std::shared_ptr<PyLocalClient> client) {
-  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues");
-  int num_arguments = static_cast<int>(arguments.size());
-  std::vector<std::unique_ptr<PyLocalBuffer>> outputs(num_arguments);
-  if (num_arguments == 0) {
-    return outputs;
-  }
-
-  struct H2DTransfer {
-    PythonBufferTree tree;
-    StatusOr<std::unique_ptr<PyLocalBuffer>> buffer;
-    PythonRefManager::ManagedPyObjects py_buffer_refs;
-  };
-
-  std::vector<H2DTransfer> transfers(num_arguments);
-  for (int i = 0; i < num_arguments; ++i) {
-    TF_ASSIGN_OR_RETURN(transfers[i].tree,
-                        GetPythonBufferTree(arguments[i].first));
-    transfers[i].py_buffer_refs = client->py_ref_manager().ManageReferences(
-        absl::MakeSpan(transfers[i].tree.arrays));
-  }
-  client->py_ref_manager().CollectGarbage();
-  // We are done manipulating Python objects; release the GIL.
-  py::gil_scoped_release gil_release;
-
-  auto transfer_h2d = [&](int i) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
-    int device_ordinal = arguments[i].second;
-    return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
-                                     client->device(device_ordinal));
-  };
-
-  // We perform the transfers on a thread pool in case XLA needs to do any
-  // host-side preprocessing of the input data.
-  if (num_arguments == 1) {
-    transfers[0].buffer = transfer_h2d(0);
-  } else {
-    absl::BlockingCounter counter(num_arguments);
-    for (int i = 0; i < num_arguments; ++i) {
-      client->h2d_transfer_pool()->Schedule([&, i]() {
-        transfers[i].buffer = transfer_h2d(i);
-        counter.DecrementCount();
-      });
-    }
-    counter.Wait();
-  }
-
-  // Release our references once the transfers have completed.
-  for (int i = 0; i < num_arguments; ++i) {
-    int device_ordinal = arguments[i].second;
-    const Device& device = client->device(device_ordinal);
-    device.ThenRelease(device.host_to_device_stream(),
-                       std::move(transfers[i].py_buffer_refs));
-  }
-
-  for (int i = 0; i < num_arguments; ++i) {
-    TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer));
-  }
-  return outputs;
+  Device* device = &client->device(device_ordinal);
+  TransferManager* transfer_manager =
+      client->client()->backend().transfer_manager();
+  se::DeviceMemoryAllocator* allocator = client->allocator();
+  TF_ASSIGN_OR_RETURN(
+      Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          shape, allocator, device_ordinal));
+  TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
+      device->host_to_device_stream(), buffer));
+
+  std::vector<std::shared_ptr<void>> staging_buffers;
+  staging_buffers.reserve(tree.leaves.size());
+  auto it = tree.leaves.begin();
+  for (const ShapeUtil::IndexedShape& indexed_shape :
+       ShapeUtil::GetLeafShapes(shape)) {
+    TF_RET_CHECK(it != tree.leaves.end());
+    ShapedBuffer leaf(
+        indexed_shape.shape,
+        transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
+        client->client()->platform(), device_ordinal);
+    leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
+    if (!transfer_manager->CanShapedBufferBeAccessedNow(
+            device->host_to_device_stream()->parent(), leaf)) {
+      device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
+    }
+
+    // If applicable on the backend, stage the transfer via host memory
+    // allocated via the host_memory_allocator. On GPU, this is pinned memory.
+    if (client->host_memory_allocator()) {
+      int64 size = it->size_bytes({});
+      void* ptr = client->host_memory_allocator()->AllocateRaw(
+          tensorflow::Allocator::kAllocatorAlignment, size);
+      std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
+        client->host_memory_allocator()->DeallocateRaw(ptr);
+      });
+      std::memcpy(ptr, it->untyped_data({}), size);
+      BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
+                               it->shape());
+      TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
+          device->host_to_device_stream(), literal, leaf));
+      staging_buffers.push_back(std::move(staging_buffer));
+    } else {
+      // Otherwise, just transfer the literal.
+      TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
+          device->host_to_device_stream(), *it, leaf));
+    }
+    ++it;
+  }
+
+  auto definition_event = std::make_shared<BufferDefinitionEvent>();
+  TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                      device->event_pool().ThenAllocateAndRecordEvent(
+                          device->host_to_device_stream()));
+  definition_event->SetDefinitionEvent(std::move(event),
+                                       device->host_to_device_stream());
+
+  std::shared_ptr<SharedDeviceBuffer> device_buffer =
+      SharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
+                                                 definition_event);
+  if (device->synchronous_deallocation()) {
+    device->ThenRelease(device->host_to_device_stream(), device_buffer);
+  }
+  device->ThenRelease(
+      device->host_to_device_stream(),
+      std::make_pair(std::move(py_buffer_ref), std::move(staging_buffers)));
+
+  return absl::make_unique<PyLocalBuffer>(shape, std::move(device_buffer),
+                                          std::move(client));
 }

 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
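The staging path in the hunk above follows a small general pattern: copy into
allocator-provided host memory (pinned on GPU), hand the transfer a view of
that memory, and tie the staging buffer's lifetime to a `shared_ptr` deleter
that is only dropped once the stream is done with it. An editorial reduced
sketch (the allocator callbacks are stand-ins, not the TensorFlow allocator
API):

```cpp
#include <cstring>
#include <functional>
#include <memory>

// Editorial sketch only: stage `size` bytes from `src` into memory obtained
// from `alloc`, returning a handle whose destruction frees the staging
// memory. The caller keeps the handle alive (e.g., via Device::ThenRelease)
// until the device transfer has consumed it.
std::shared_ptr<void> StageForTransfer(const void* src, size_t size,
                                       std::function<void*(size_t)> alloc,
                                       std::function<void(void*)> dealloc) {
  void* ptr = alloc(size);
  // The deleter runs when the last reference is dropped.
  std::shared_ptr<void> staging(ptr, [dealloc](void* p) { dealloc(p); });
  std::memcpy(ptr, src, size);
  return staging;
}
```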
@ -465,13 +352,9 @@ PyLocalBuffer::FromPythonValues(
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
-  const Device& device = client->device(device_ordinal);
-  std::shared_ptr<BufferDefinitionEvent> definition_event;
-  if (device.use_multiple_streams()) {
-    TF_ASSIGN_OR_RETURN(definition_event,
-                        BufferDefinitionEvent::Create(
-                            device.host_to_device_stream()->parent()));
-  }
+  Device& device = client->device(device_ordinal);
+  auto definition_event = std::make_shared<BufferDefinitionEvent>();
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
       SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,

@ -482,21 +365,22 @@ PyLocalBuffer::FromPythonValues(
   // TODO(phawkins): extend TransferManager so we do not need to form a full
   // ShapedBuffer just to write the root tuple index table.
   TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
-  if (device.use_multiple_streams() &&
-      !transfer_manager->CanShapedBufferBeAccessedNow(
+  if (!transfer_manager->CanShapedBufferBeAccessedNow(
           device.host_to_device_stream()->parent(), shaped_buffer)) {
     // Wait for the compute stream so that memory allocations are synchronized.
     device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
   }
   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
       device.host_to_device_stream(), shaped_buffer));
-  if (definition_event) {
-    definition_event->RecordOnStream(device.host_to_device_stream());
-  }
+  TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                      device.event_pool().ThenAllocateAndRecordEvent(
+                          device.host_to_device_stream()));
+  definition_event->SetDefinitionEvent(std::move(event),
+                                       device.host_to_device_stream());

   if (device.synchronous_deallocation()) {
-    device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
-                                     std::move(tuple_buffer));
+    device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer));
   }
   return buffer;
 }

@ -629,8 +513,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
       ScopedShapedBuffer dst_buffer,
       transfer_manager->AllocateScopedShapedBuffer(
           on_host_shape_, client_->allocator(), dst_device_ordinal));
-  if (dst_device.use_multiple_streams() &&
-      !transfer_manager->CanShapedBufferBeAccessedNow(
+  if (!transfer_manager->CanShapedBufferBeAccessedNow(
           dst_device.compute_stream()->parent(), dst_buffer)) {
     src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
   }

@ -652,6 +535,10 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
         output_buffer));
   }

+  // We hold on to the `src_device_buffer` until the transfer is finished.
+  src_device.ThenRelease(src_device_to_device_stream,
+                         std::move(src_device_buffer));
+
   // Write new tuple buffers. The destination buffers have different addresses,
   // so we must construct tuple buffers from scratch instead of copying them.
   if (dst_buffer.on_device_shape().IsTuple()) {

@ -665,13 +552,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
         dst_device.host_to_device_stream());
   }

-  std::shared_ptr<BufferDefinitionEvent> definition_event;
-  if (dst_device.use_multiple_streams()) {
-    TF_ASSIGN_OR_RETURN(
-        definition_event,
-        BufferDefinitionEvent::Create(src_device_to_device_stream->parent()));
-    definition_event->RecordOnStream(src_device_to_device_stream);
-  }
+  auto definition_event = std::make_shared<BufferDefinitionEvent>();
+  TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                      src_device.event_pool().ThenAllocateAndRecordEvent(
+                          src_device_to_device_stream));
+  definition_event->SetDefinitionEvent(std::move(event),
+                                       src_device_to_device_stream);

   std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
@ -756,21 +642,21 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
             << " buffer: " << argument_buffers.back().ToString();
   }

-  const Device& device = client_->device(device_ordinal);
-  // The choice of where we wait in "synchronous" mode is arbitrary; the reason
-  // for the wait is pacing to avoid problems such as memory fragmentation, not
-  // for correctness.
-  if (!device.asynchronous()) {
-    TF_RETURN_IF_ERROR(device.compute_stream()->BlockHostUntilDone());
-  }
+  Device* device = &client_->device(device_ordinal);
+  // The choice of where we wait is arbitrary; the reason for the wait is pacing
+  // to avoid problems such as memory fragmentation and running ahead too far,
+  // not for correctness. Placing it before the executable launch allows the
+  // inputs for the next executable to be fetched even if the launch is delayed.
+  auto compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
+      device->compute_semaphore().ScopedAcquire(1));

   for (BufferDefinitionEvent* event : events) {
-    event->WaitForEventOnStream(device.compute_stream());
+    event->WaitForEventOnStream(device->compute_stream());
   }

   ExecutableRunOptions options;
-  options.set_stream(device.compute_stream());
-  options.set_host_to_device_stream(device.host_to_device_stream());
+  options.set_stream(device->compute_stream());
+  options.set_host_to_device_stream(device->host_to_device_stream());
   options.set_allocator(client_->allocator());
   options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
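The hunk above swaps the old synchronous-mode `BlockHostUntilDone` for semaphore-based pacing: each launch acquires one unit, and the unit is returned when the device retires the work, so the host can run ahead by at most `capacity` executions. A minimal stand-in for the mechanism in standard C++ (this `PacingSemaphore` is a sketch, not the new `xla::Semaphore`):

#include <condition_variable>
#include <mutex>

class PacingSemaphore {
 public:
  explicit PacingSemaphore(int capacity) : value_(capacity) {}
  // Blocks while `capacity` launches are already in flight.
  void Acquire() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return value_ > 0; });
    --value_;
  }
  // Called from the completion path when the device retires one launch.
  void Release() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      ++value_;
    }
    cv_.notify_one();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int value_;
};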
@ -787,24 +673,25 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
     return result_buffer.status();
   }

-  std::shared_ptr<BufferDefinitionEvent> definition_event;
-  if (device.use_multiple_streams()) {
-    TF_ASSIGN_OR_RETURN(
-        definition_event,
-        BufferDefinitionEvent::Create(device.compute_stream()->parent()));
-    definition_event->RecordOnStream(device.compute_stream());
-  }
+  auto definition_event = std::make_shared<BufferDefinitionEvent>();
+  TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                      device->event_pool().ThenAllocateAndRecordEvent(
+                          device->compute_stream()));
+  definition_event->SetDefinitionEvent(std::move(event),
+                                       device->compute_stream());
+
   Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape();
   std::shared_ptr<SharedDeviceBuffer> out_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(
           std::move(result_buffer.ValueOrDie()), definition_event);

-  if (device.synchronous_deallocation()) {
+  if (device->synchronous_deallocation()) {
     device_buffers.push_back(out_buffer);
-    device.ThenReleaseOnWorkerThread(device.compute_stream(),
-                                     std::move(device_buffers));
+    device->ThenRelease(device->compute_stream(), std::move(device_buffers));
   }
-  device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_);
+  device->ThenRelease(device->compute_stream(),
+                      std::make_pair(executable_, compute_reservation));
   return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer),
                                           client_);
 }

@ -853,7 +740,7 @@ PyLocalExecutable::ExecutePerReplica(
   for (int replica = 0; replica < num_replicas(); ++replica) {
     const int device_ordinal = device_assignment_(replica, 0);
     const Device& device = client_->device(device_ordinal);
-    device.worker_thread()->Schedule([&, replica] {
+    device.execute_thread()->Schedule([&, replica] {
       results[replica] =
           ExecuteHelper(argument_handles[replica], replica, run_id);

@ -16,7 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
 #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_

-#include <deque>
+#include <memory>
 #include <string>
 #include <vector>
@ -27,137 +27,19 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/python/device.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
-#include "tensorflow/compiler/xla/python/worker_thread.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/status.h"

 namespace xla {

-// Registers a 'fn_capsule' as a CPU custom call target.
-// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
-// "xla._CPU_CUSTOM_CALL_TARGET".
-Status RegisterCpuCustomCallTarget(const std::string& fn_name,
-                                   pybind11::capsule capsule);
-
-// Class that encapsulates state relating to a device (e.g., a GPU) on which we
-// can perform computation and transfers.
-class Device {
- public:
-  // If use_multiple_streams is true, we allocate separate streams for compute
-  // and transfers. If it is false, we share a single stream for compute and
-  // transfers. The CPU device does not support multiple streams, and this is
-  // a workaround until it does.
-  //
-  // If synchronous_deallocation is true, the host must not free buffers until
-  // compute/transfers that use those buffers have completed. For example, this
-  // typically is the case for the "platform" where compute/transfers are
-  // operations that take place on another thread.
-  //
-  // If asynchronous is false, the host will synchronize to the device after
-  // each execution or transfer. This is intended for debugging only.
-  Device(se::StreamExecutor* executor, bool use_multiple_streams,
-         bool synchronous_deallocation, bool asynchronous);
-  virtual ~Device();
-
-  bool use_multiple_streams() const { return use_multiple_streams_; }
-  bool synchronous_deallocation() const { return synchronous_deallocation_; }
-  bool asynchronous() const { return asynchronous_; }
-  se::Stream* compute_stream() const { return compute_stream_.get(); }
-  se::Stream* host_to_device_stream() const {
-    return host_to_device_stream_.get();
-  }
-  se::Stream* device_to_host_stream() const {
-    return device_to_host_stream_.get();
-  }
-
-  // Returns a device to device stream. Allocates streams in a round-robin
-  // fashion amongst the available streams.
-  se::Stream* GetDeviceToDeviceStream();
-
-  // Enqueues a copy of `src_buffer` to `dst_buffer` onto `src_stream`.
-  virtual Status ThenMemcpyDeviceToDevice(se::Stream* src_stream,
-                                          se::Stream* dst_stream,
-                                          se::DeviceMemoryBase src_buffer,
-                                          se::DeviceMemoryBase dst_buffer);
-
-  // A worker thread, used for replicated computation launches and callbacks.
-  WorkerThread* worker_thread() const { return worker_thread_.get(); }
-
-  // Enqueues a host callback on 'stream', to be executed by worker_thread_.
-  // ThenDoHostCallback is often constrained in what it can do, in particular,
-  // on GPU the callback runs on a thread belonging to the GPU runtime and
-  // cannot perform GPU operations itself.
-  void ThenExecuteOnWorkerThread(se::Stream* stream,
-                                 std::function<void()> callback) const;
-
-  // Helper for releasing values from a callback at the tail of a stream.
-  // This is only permitted if object's destructor will not free any device
-  // objects, since the callback may be called from a device thread pool on
-  // GPU.
-  template <typename T>
-  void ThenRelease(se::Stream* stream, T object) const {
-    if (callback_stream_.get() != stream) {
-      callback_stream_->ThenWaitFor(stream);
-    }
-    callback_stream_->ThenDoHostCallback(
-        std::bind([](T& object) { /* releases object */ }, std::move(object)));
-  }
-
-  // Helpers for releasing values on a worker thread at the tail of a stream on
-  // a worker thread.
-  template <typename T>
-  void ThenReleaseOnWorkerThread(se::Stream* stream,
-                                 std::shared_ptr<T> object) const {
-    // We use a non-smart pointer here because we want to ensure that the worker
-    // thread is the only callee of the shared_ptr destructor, and if we passed
-    // object by lambda capture we have a race where the worker thread might
-    // run and release its reference first.
-    auto* ref = new std::shared_ptr<T>(std::move(object));
-    if (callback_stream_.get() != stream) {
-      callback_stream_->ThenWaitFor(stream);
-    }
-    ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
-  }
-  template <typename T>
-  void ThenReleaseOnWorkerThread(se::Stream* stream,
-                                 std::vector<std::shared_ptr<T>> object) const {
-    auto* ref = new std::vector<std::shared_ptr<T>>(std::move(object));
-    if (callback_stream_.get() != stream) {
-      callback_stream_->ThenWaitFor(stream);
-    }
-    ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
-  }
-
- private:
-  Status SynchronizeAllActivity();
-
-  bool use_multiple_streams_;
-  bool synchronous_deallocation_;
-  bool asynchronous_;
-  std::shared_ptr<se::Stream> compute_stream_;
-  std::shared_ptr<se::Stream> host_to_device_stream_;
-  std::shared_ptr<se::Stream> device_to_host_stream_;
-  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
-
-  // Number of device-to-device streams to create in the multistream case.
-  static constexpr int kNumDeviceToDeviceStreams = 4;
-
-  absl::Mutex mu_;
-  int next_device_to_device_stream_ GUARDED_BY(mu_) = 0;
-
-  // Callback stream is used for running short host-side callbacks after device
-  // side events, without preventing the device-side stream from doing useful
-  // work.
-  std::shared_ptr<se::Stream> callback_stream_;
-
-  std::unique_ptr<WorkerThread> worker_thread_;
-};
-
 struct AllocatorConfig {
   enum class Kind {
     kDefault,  // Client picks the best option for the platform.
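The `ThenRelease` helper deleted above (the Device class moves to device.h) pins an object's lifetime to a point in a stream by moving it into a host callback enqueued there. A toy analogue, with a plain task queue standing in for `se::Stream` (all names hypothetical):

#include <functional>
#include <memory>
#include <queue>

using Stream = std::queue<std::function<void()>>;  // toy "stream"

// Keeps `object` alive until the stream reaches this point: the enqueued
// callback co-owns it, and popping the callback drops that reference.
template <typename T>
void ThenRelease(Stream* stream, std::shared_ptr<T> object) {
  stream->push([object]() { /* no-op: runs at the release point */ });
}

int main() {
  Stream stream;
  auto payload = std::make_shared<int>(42);
  ThenRelease(&stream, payload);
  payload.reset();  // still alive: the enqueued callback holds a reference
  while (!stream.empty()) {
    stream.front()();  // reaching the release point...
    stream.pop();      // ...drops the callback and frees the payload
  }
  return 0;
}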
@ -188,10 +70,11 @@ class PyLocalClient {
       bool asynchronous, const AllocatorConfig& allocator_config);

   // `allocator` may null, in which case the platform default allocator is used.
-  explicit PyLocalClient(std::string platform_name, LocalClient* client,
-                         std::vector<std::unique_ptr<Device>> devices,
-                         std::unique_ptr<se::DeviceMemoryAllocator> allocator,
-                         bool asynchronous);
+  explicit PyLocalClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<Device>> devices,
+      std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator);
   virtual ~PyLocalClient() = default;

   Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);

@ -204,6 +87,9 @@ class PyLocalClient {
   }
   LocalClient* client() const { return client_; }
   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }

   tensorflow::thread::ThreadPool* h2d_transfer_pool() {
     return &h2d_transfer_pool_;

@ -225,6 +111,11 @@ class PyLocalClient {
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

+  // Allocator to be used for staging memory transfers to devices. Optional;
+  // only used on GPU where it is more efficient to copy buffers to and from the
+  // device via a staging area of pinned memory.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
   tensorflow::thread::ThreadPool h2d_transfer_pool_;
 };

@ -241,12 +132,6 @@ class PyLocalBuffer {
       const pybind11::object& argument, std::shared_ptr<PyLocalClient> client,
       int device_ordinal);

-  // Converts multiple (python object, device ordinal) pairs into
-  // PyLocalBuffers in parallel.
-  static StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> FromPythonValues(
-      const std::vector<std::pair<pybind11::object, int>>& argument,
-      std::shared_ptr<PyLocalClient> client);
-
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
       const std::vector<PyLocalBuffer*> buffers,
       std::shared_ptr<PyLocalClient> client, int device_ordinal);

@ -37,9 +37,9 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
   }
 }

-PythonRefManager::ManagedPyObjects PythonRefManager::ManageReferences(
-    absl::Span<py::object> objects) {
-  return ManagedPyObjects(this, objects);
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
+  return std::make_shared<ManagedPyObjects>(this, objects);
 }

 void PythonRefManager::CollectGarbage() {

@ -48,9 +48,9 @@ class PythonRefManager {

     ~ManagedPyObjects();

-    ManagedPyObjects(const ManagedPyObjects& other) = default;
+    ManagedPyObjects(const ManagedPyObjects& other) = delete;
     ManagedPyObjects(ManagedPyObjects&& other) = default;
-    ManagedPyObjects& operator=(const ManagedPyObjects& other) = default;
+    ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
     ManagedPyObjects& operator=(ManagedPyObjects&& other) = default;

    private:

@ -61,7 +61,8 @@ class PythonRefManager {
   // Creates a managed std::shared_ptr to an object. When the shared_ptr is
   // destroyed, the reference to 'object' will be added to python_garbage_,
   // and collected next time CollectGarbage() is called.
-  ManagedPyObjects ManageReferences(absl::Span<pybind11::object> objects);
+  std::shared_ptr<ManagedPyObjects> ManageReferences(
+      absl::Span<pybind11::object> objects);

   // Releases the contents of python_garbage_. Requires that the GIL is held.
   // The client calls this method during API entry points where the GIL is held
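The switch to `std::shared_ptr<ManagedPyObjects>` with deleted copy operations means one keep-alive can be co-owned by several holders (a buffer, a stream callback) while its cleanup runs exactly once, when the last holder goes away. A trivial demonstration of that ownership shape in standard C++ (the `KeepAlive` type is a stand-in for `ManagedPyObjects`):

#include <cstdio>
#include <memory>

struct KeepAlive {
  ~KeepAlive() { std::puts("released exactly once"); }
};

int main() {
  auto ka = std::make_shared<KeepAlive>();
  std::shared_ptr<KeepAlive> holder1 = ka;  // e.g. held by a buffer
  std::shared_ptr<KeepAlive> holder2 = ka;  // e.g. held by a callback
  ka.reset();
  holder1.reset();
  holder2.reset();  // destructor runs here, once, for all co-owners
  return 0;
}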
tensorflow/compiler/xla/python/semaphore.cc (new file, 74 lines)
@ -0,0 +1,74 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/semaphore.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+Semaphore::Semaphore(int64 capacity) : value_(capacity) {
+  CHECK_GE(capacity, 0);
+}
+
+bool Semaphore::CanAcquire(CanAcquireArgs* args) {
+  return args->semaphore->value_ >= args->amount;
+}
+
+void Semaphore::Acquire(int64 amount) {
+  CHECK_GE(amount, 0);
+
+  CanAcquireArgs args;
+  args.semaphore = this;
+  args.amount = amount;
+
+  mu_.LockWhen(absl::Condition(&CanAcquire, &args));
+  value_ -= amount;
+  mu_.Unlock();
+}
+
+void Semaphore::Release(int64 amount) {
+  CHECK_GE(amount, 0);
+  absl::MutexLock lock(&mu_);
+  value_ += amount;
+}
+
+Semaphore::ScopedReservation::~ScopedReservation() {
+  if (semaphore_) {
+    semaphore_->Release(amount_);
+  }
+}
+
+Semaphore::ScopedReservation::ScopedReservation(
+    ScopedReservation&& other) noexcept {
+  semaphore_ = other.semaphore_;
+  amount_ = other.amount_;
+  other.semaphore_ = nullptr;
+}
+
+Semaphore::ScopedReservation& Semaphore::ScopedReservation::operator=(
+    ScopedReservation&& other) noexcept {
+  semaphore_ = other.semaphore_;
+  amount_ = other.amount_;
+  other.semaphore_ = nullptr;
+  return *this;
+}
+
+Semaphore::ScopedReservation Semaphore::ScopedAcquire(int64 amount) {
+  Acquire(amount);
+  return ScopedReservation(this, amount);
+}
+
+} // namespace xla
tensorflow/compiler/xla/python/semaphore.h (new file, 67 lines)
@ -0,0 +1,67 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0; same header as
+   semaphore.cc above. */
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+class Semaphore {
+ public:
+  explicit Semaphore(int64 capacity);
+
+  // Acquires `amount` units. Blocks until `amount` units are available.
+  void Acquire(int64 amount);
+
+  // Returns `amount` units to the semaphore.
+  void Release(int64 amount);
+
+  class ScopedReservation {
+   public:
+    ScopedReservation(Semaphore* semaphore, int64 amount)
+        : semaphore_(semaphore), amount_(amount) {}
+    ~ScopedReservation();
+
+    ScopedReservation(const ScopedReservation&) = delete;
+    ScopedReservation(ScopedReservation&& other) noexcept;
+    ScopedReservation& operator=(const ScopedReservation&) = delete;
+    ScopedReservation& operator=(ScopedReservation&& other) noexcept;
+
+   private:
+    Semaphore* semaphore_;
+    int64 amount_;
+  };
+  // RAII version of Acquire. Releases the reservation when the
+  // ScopedReservation is destroyed.
+  ScopedReservation ScopedAcquire(int64 amount);
+
+ private:
+  struct CanAcquireArgs {
+    Semaphore* semaphore;
+    int64 amount;
+  };
+  static bool CanAcquire(CanAcquireArgs* args)
+      EXCLUSIVE_LOCKS_REQUIRED(args->semaphore->mu_);
+
+  absl::Mutex mu_;
+  int64 value_ GUARDED_BY(mu_);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_
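A sketch of how a caller is expected to combine `ScopedAcquire` with deferred release so that at most `capacity` executions run ahead. Only the `Semaphore` interface comes from this file; the launch-loop scaffolding and `completion_callbacks` parking lot below are hypothetical:

#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/python/semaphore.h"

void LaunchAll(xla::Semaphore* semaphore, int num_launches,
               std::vector<std::shared_ptr<xla::Semaphore::ScopedReservation>>*
                   completion_callbacks) {
  for (int i = 0; i < num_launches; ++i) {
    // Blocks once `capacity` launches are in flight.
    auto reservation = std::make_shared<xla::Semaphore::ScopedReservation>(
        semaphore->ScopedAcquire(1));
    // ... enqueue one execution ...
    // Park the reservation with the completion machinery; the unit is
    // released when the last shared_ptr is dropped, as ExecuteHelper does
    // via device->ThenRelease(...).
    completion_callbacks->push_back(std::move(reservation));
  }
}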
tensorflow/compiler/xla/python/semaphore_test.cc (new file, 74 lines)
@ -0,0 +1,74 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0; same header as
+   semaphore.cc above. */
+
+#include "tensorflow/compiler/xla/python/semaphore.h"
+
+#include "absl/synchronization/notification.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace xla {
+namespace {
+
+TEST(SemaphoreTest, UnthreadedTests) {
+  Semaphore semaphore(2);
+  semaphore.Acquire(1);
+  semaphore.Release(1);
+
+  semaphore.Acquire(2);
+  semaphore.Release(2);
+
+  semaphore.Acquire(1);
+  semaphore.Acquire(1);
+  semaphore.Release(1);
+  semaphore.Acquire(1);
+  semaphore.Release(1);
+  semaphore.Acquire(1);
+  semaphore.Release(2);
+
+  {
+    auto a = semaphore.ScopedAcquire(1);
+    { auto b = semaphore.ScopedAcquire(1); }
+    { auto c = semaphore.ScopedAcquire(1); }
+  }
+}
+
+TEST(SemaphoreTest, ConcurrentTest) {
+  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
+  Semaphore semaphore(2);
+  semaphore.Acquire(1);
+
+  absl::Notification a_done;
+  pool.Schedule([&]() {
+    semaphore.Acquire(2);
+    semaphore.Release(2);
+    a_done.Notify();
+  });
+
+  absl::Notification b_done;
+  pool.Schedule([&]() {
+    semaphore.Acquire(1);
+    semaphore.Release(1);
+    b_done.Notify();
+  });
+  b_done.WaitForNotification();
+  EXPECT_FALSE(a_done.HasBeenNotified());
+  semaphore.Release(1);
+  a_done.WaitForNotification();
+}
+
+} // namespace
+} // namespace xla
@ -21,21 +21,12 @@ limitations under the License.

 namespace xla {

-/*static*/ StatusOr<std::shared_ptr<BufferDefinitionEvent>>
-BufferDefinitionEvent::Create(se::StreamExecutor* executor) {
-  auto event = std::make_shared<BufferDefinitionEvent>(executor);
-  TF_RET_CHECK(event->event_.Init())
-      << "Buffer definition event initialization failed";
-  return event;
-}
-
-BufferDefinitionEvent::BufferDefinitionEvent(se::StreamExecutor* executor)
-    : event_(executor) {}
-
-void BufferDefinitionEvent::RecordOnStream(se::Stream* stream) {
+void BufferDefinitionEvent::SetDefinitionEvent(EventPool::Handle event,
+                                               se::Stream* stream) {
   absl::MutexLock lock(&mu_);
+  CHECK(!event_.event());
+  event_ = std::move(event);
   CHECK(streams_defined_on_.empty());
-  stream->ThenRecordEvent(&event_);
   streams_defined_on_.push_back(stream);
 }

@ -50,7 +41,7 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) {
     return;
   }

-  stream->ThenWaitFor(&event_);
+  stream->ThenWaitFor(event_.event());
   streams_defined_on_.push_back(stream);
 }

@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_

 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/python/event_pool.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape.h"

@ -50,14 +51,11 @@ namespace xla {
 // same stream causes no additional waiting.
 class BufferDefinitionEvent {
  public:
-  // Creates a new definition event whose event has not yet been triggered.
-  static StatusOr<std::shared_ptr<BufferDefinitionEvent>> Create(
-      se::StreamExecutor* executor);
-
-  explicit BufferDefinitionEvent(se::StreamExecutor* executor);
-
-  // Records the definition event on the tail of 'stream'.
-  void RecordOnStream(se::Stream* stream);
+  BufferDefinitionEvent() = default;
+
+  // Sets the definition event of the buffer to 'event', which is recorded
+  // on 'stream'. Must be called at most once.
+  void SetDefinitionEvent(EventPool::Handle event, se::Stream* stream);

   // Adds synchronization events to 'stream' that wait for this event to be
   // defined on 'stream'. Does nothing if the event is already known to have

@ -68,7 +66,7 @@ class BufferDefinitionEvent {
   // An event that is triggered when the content of one or more buffers is
   // ready. If this event is nullptr, it is assumed that the buffer's content is
   // always defined.
-  se::Event event_;
+  EventPool::Handle event_;

   absl::Mutex mu_;
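The reworked `BufferDefinitionEvent` is a set-once, wait-many rendezvous: the producing stream records the event exactly once via `SetDefinitionEvent`, and any number of consuming streams wait on it, whether they arrive before or after the definition. A self-contained analogue with `std::promise`/`std::shared_future` (standard C++ only, not the XLA types):

#include <future>
#include <thread>

int main() {
  std::promise<void> defined;  // ~ SetDefinitionEvent, callable once
  std::shared_future<void> event = defined.get_future().share();

  // ~ WaitForEventOnStream: each consumer blocks until the definition point.
  std::thread consumer_a([event] { event.wait(); });
  std::thread consumer_b([event] { event.wait(); });

  defined.set_value();  // definition recorded; both consumers unblock
  consumer_a.join();
  consumer_b.join();
  return 0;
}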
@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <cstdint>
 #include <string>
 #include <vector>

+#include "absl/base/casts.h"
 #include "absl/hash/hash.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"

@ -34,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/python/local_client.h"
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/python/xrt.h"
+#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"

@ -104,6 +108,22 @@ StatusOr<std::string> GetComputationHloDotGraph(
                          RenderedGraphFormat::kDot);
 }

+// Registers a 'fn_capsule' as a CPU custom call target.
+// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
+// "xla._CPU_CUSTOM_CALL_TARGET".
+Status RegisterCpuCustomCallTarget(const std::string& fn_name,
+                                   py::capsule capsule) {
+  static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET";
+  if (absl::string_view(capsule.name()) != kName) {
+    return InvalidArgument(
+        "Argument to RegisterCpuCustomCallTargetRegistry was not a "
+        "xla._CPU_CUSTOM_CALL_TARGET capsule.");
+  }
+  CustomCallTargetRegistry::Global()->Register(
+      fn_name, static_cast<void*>(capsule), "Host");
+  return Status::OK();
+}
+
 } // namespace

 PYBIND11_MODULE(xla_extension, m) {
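Stripped of the PyCapsule plumbing, `RegisterCpuCustomCallTarget` is a plain registry insertion. A hypothetical target registered directly from C++; the two-argument CPU custom-call signature below is an assumption about the conventional ABI, not something this diff states:

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

// Hypothetical CPU custom-call target (assumed ABI: output pointer first,
// then an array of input pointers).
void MyAddOne(void* out, const void** in) {
  *static_cast<float*>(out) = *static_cast<const float*>(in[0]) + 1.0f;
}

void RegisterMyTarget() {
  // Same call the diff makes, with a raw function pointer in place of the
  // capsule's payload.
  xla::CustomCallTargetRegistry::Global()->Register(
      "my_add_one", reinterpret_cast<void*>(&MyAddOne), "Host");
}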
@ -297,7 +317,6 @@ PYBIND11_MODULE(xla_extension, m) {

   py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
       .def_static("from_python", &PyLocalBuffer::FromPython)
-      .def_static("from_python_values", &PyLocalBuffer::FromPythonValues)
       .def_static("make_tuple", &PyLocalBuffer::MakeTuple)
       .def("copy_to_device", &PyLocalBuffer::CopyToDevice)
       .def("delete", &PyLocalBuffer::Delete)

@ -307,9 +326,22 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("to_py", &PyLocalBuffer::ToPython)
       .def("shape", &PyLocalBuffer::on_host_shape)
       .def("device", &PyLocalBuffer::device_ordinal)
-      .def("is_deleted", [](const PyLocalBuffer& buffer) {
-        return buffer.DeviceBuffer() == nullptr;
-      });
+      .def("is_deleted",
+           [](const PyLocalBuffer& buffer) {
+             return buffer.DeviceBuffer() == nullptr;
+           })
+      .def("unsafe_buffer_pointer",
+           [](const PyLocalBuffer& buffer) -> StatusOr<std::uintptr_t> {
+             TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer,
+                                 buffer.AsShapedBuffer());
+             if (shaped_buffer.on_device_shape().IsTuple()) {
+               return Unimplemented(
+                   "unsafe_buffer_pointer is not implemented for tuple "
+                   "buffers.");
+             }
+             return absl::bit_cast<std::uintptr_t>(
+                 shaped_buffer.root_buffer().opaque());
+           });

   py::class_<PyLocalExecutable>(m, "LocalExecutable")
       .def_static("Compile", &PyLocalExecutable::Compile,
@ -98,10 +98,6 @@ class LocalBackend(Backend):
   def buffer_from_pyval(self, pyval, device=0):
     return _xla.PyLocalBuffer.from_python(pyval, self.client, device)

-  def buffers_from_pyvals(self, pyvals_and_devices):
-    return _xla.PyLocalBuffer.from_python_values(pyvals_and_devices,
-                                                 self.client)
-
   def make_tuple(self, c_buffers, device_ordinal):
     return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal)
@ -644,6 +644,7 @@ cc_library(
     hdrs = ["call_inliner.h"],
     deps = [
         ":call_graph",
+        ":hlo",
         ":hlo_dce",
         ":hlo_pass",
         "//tensorflow/compiler/xla:statusor",

@ -1095,50 +1096,6 @@ tf_cc_test(
     ],
 )

-cc_library(
-    name = "buffer_liveness",
-    srcs = [
-        "buffer_liveness.cc",
-    ],
-    hdrs = [
-        "buffer_liveness.h",
-    ],
-    deps = [
-        ":hlo",
-        ":hlo_ordering",
-        ":logical_buffer",
-        ":tuple_points_to_analysis",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/core:lib",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-    ],
-)
-
-tf_cc_test(
-    name = "buffer_liveness_test",
-    srcs = ["buffer_liveness_test.cc"],
-    deps = [
-        ":buffer_liveness",
-        ":hlo",
-        ":hlo_dataflow_analysis",
-        ":hlo_parser",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
-        "//tensorflow/core:test",
-        "@com_google_absl//absl/memory",
-    ],
-)
-
 cc_library(
     name = "buffer_assignment",
     srcs = [

@ -1148,7 +1105,6 @@ cc_library(
         "buffer_assignment.h",
     ],
     deps = [
-        ":buffer_liveness",
         ":buffer_value_containers",
         ":heap_simulator",
         ":hlo",

@ -2785,7 +2741,6 @@ cc_library(
     srcs = ["copy_insertion.cc"],
     hdrs = ["copy_insertion.h"],
     deps = [
-        ":buffer_liveness",
         ":dump",
         ":hlo",
         ":hlo_alias_analysis",

@ -2907,7 +2862,6 @@ cc_library(
     srcs = ["hlo_rematerialization.cc"],
     hdrs = ["hlo_rematerialization.h"],
     deps = [
-        ":buffer_liveness",
         ":buffer_value",
         ":call_graph",
         ":flatten_call_graph",

@ -3626,7 +3580,6 @@ cc_library(
     srcs = ["reduce_precision_insertion.cc"],
     hdrs = ["reduce_precision_insertion.h"],
     deps = [
-        ":buffer_liveness",
         ":hlo",
         ":hlo_pass",
         ":hlo_pass_pipeline",

@ -3946,7 +3899,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@ -168,13 +168,8 @@ bool IsUnstridedSlice(const HloInstruction* hlo) {
 // algebraic expressions to simplified forms. Note: This only supports
 // simplifications that simply look at the operands of an instruction. For the
 // more general case a worklist based approach would be needed.
-class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
+class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
  public:
-  // Default visitor action is to do nothing and return OK.
-  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
-    return Status::OK();
-  }
-
   Status HandleAdd(HloInstruction* add) override;

   Status HandleAnd(HloInstruction* logical_and) override;

@ -250,9 +245,6 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {

   Status HandleMap(HloInstruction* map) override;

-  // Returns whether algebraic simplification has occurred.
-  const bool changed() const { return changed_; }
-
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation,
                   const AlgebraicSimplifierOptions& options,

@ -350,35 +342,6 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
   StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
       HloInstruction* broadcast);

-  // Replaces the existing HLO instruction old_instruction, with
-  // new_instruction, and marks the optimizer status as changed.
-  // Returns the Status representing the result of the replace operation.
-  Status ReplaceWithNewInstruction(
-      HloInstruction* old_instruction,
-      std::unique_ptr<HloInstruction> new_instruction) {
-    VLOG(3) << "Replacing instruction:";
-    VLOG(3) << "  old: " << old_instruction->ToString();
-    VLOG(3) << "  new: " << new_instruction->ToString();
-    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
-        old_instruction, std::move(new_instruction)));
-    changed_ = true;
-    return Status::OK();
-  }
-
-  // Replaces the existing HLO instruction old_instruction, with
-  // new_instruction, and marks the optimizer status as changed.
-  // Returns the Status representing the result of the replace operation.
-  Status ReplaceInstruction(HloInstruction* old_instruction,
-                            HloInstruction* new_instruction) {
-    VLOG(3) << "Replacing instruction:";
-    VLOG(3) << "  old: " << old_instruction->ToString();
-    VLOG(3) << "  new: " << new_instruction->ToString();
-    TF_RETURN_IF_ERROR(
-        computation_->ReplaceInstruction(old_instruction, new_instruction));
-    changed_ = true;
-    return Status::OK();
-  }
-
   StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
   StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
       const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,

@ -445,7 +408,7 @@ bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                      AlgebraicSimplifier* simplifier) {
   AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
   TF_CHECK_OK(computation->Accept(&visitor));
-  return visitor.changed_;
+  return visitor.changed_ || visitor.changed();
 }

 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,

@ -1723,6 +1686,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
 }

 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
+  CHECK(computation_ == dot->parent());
   HloInstruction *lhs, *rhs;
   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
   if (options_.is_layout_sensitive()) {
@ -2660,6 +2624,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
   HloInstruction *a, *b;
   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

+  // (A % B) % B == A % B.
+  if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
+    return ReplaceInstruction(remainder, a);
+  }
+
   // A % B => A & (B - 1) if B is a power of 2.
   switch (remainder->shape().element_type()) {
     case S8:
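The new `(A % B) % B == A % B` rewrite is sound for truncated integer remainder because the first remainder already lies strictly inside (-|B|, |B|), so taking the remainder by B again is the identity. C++'s `%` has the same truncated semantics for signed integers (an assumption this sketch relies on matching HLO's remainder), so the identity can be checked exhaustively on a small range:

#include <cassert>
#include <cstdlib>

int main() {
  for (int a = -100; a <= 100; ++a) {
    for (int b = -10; b <= 10; ++b) {
      if (b == 0) continue;  // remainder by zero is undefined
      int r = a % b;
      assert(std::abs(r) < std::abs(b));  // first remainder is already reduced
      assert(r % b == r);                 // so the second remainder is a no-op
    }
  }
  return 0;
}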
@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
   ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
 }

+TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p = s32[1000] parameter(0)
+      q = s32[1000] parameter(1)
+      r = s32[1000] remainder(p, q)
+      ROOT rr = s32[1000] remainder(r, q)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
+}
+
 } // namespace
 } // namespace xla
--- a/tensorflow/compiler/xla/service/batch_norm_expander.cc
+++ b/tensorflow/compiler/xla/service/batch_norm_expander.cc
@@ -46,13 +46,8 @@ using absl::optional;
 
 // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
 // operations into smaller operations.
-class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
+class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
  public:
-  // Default visitor action is to do nothing and return OK.
-  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
-    return Status::OK();
-  }
-
   Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
 
   Status HandleBatchNormInference(HloInstruction* batch_norm) override;
@@ -63,9 +58,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
   static bool Run(HloComputation* computation, bool rewrite_training_op,
                   bool rewrite_inference_op, bool rewrite_grad_op);
 
-  // Returns whether any batch norm ops were rewritten.
-  const bool changed() const { return changed_; }
-
   ~BatchNormExpanderVisitor() override = default;
 
  private:
@@ -133,28 +125,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
         elements_per_feature_u32);
   }
 
-  // Replaces the existing HLO instruction old_instruction, with
-  // new_instruction, and marks the optimizer status as changed.
-  // Returns the Status representing the result of the replace operation.
-  Status ReplaceWithNewInstruction(
-      HloInstruction* old_instruction,
-      std::unique_ptr<HloInstruction> new_instruction) {
-    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
-        old_instruction, std::move(new_instruction)));
-    changed_ = true;
-    return Status::OK();
-  }
-
-  // Replaces the existing HLO instruction old_instruction, with
-  // new_instruction, and marks the optimizer status as changed.
-  // Returns the Status representing the result of the replace operation.
-  Status ReplaceInstruction(HloInstruction* old_instruction,
-                            HloInstruction* new_instruction) {
-    TF_RETURN_IF_ERROR(
-        computation_->ReplaceInstruction(old_instruction, new_instruction));
-    changed_ = true;
-    return Status::OK();
-  }
   // Current HloComputation instance the BatchNormExpander is
   // traversing.
   HloComputation* computation_;
@@ -162,9 +132,6 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
   bool rewrite_training_op_;
   bool rewrite_inference_op_;
   bool rewrite_grad_op_;
-
-  // Whether rewrite has occurred.
-  bool changed_ = false;
 };
 
 }  // namespace
@@ -179,7 +146,7 @@ bool BatchNormExpanderVisitor::Run(HloComputation* computation,
       /*rewrite_inference_op=*/rewrite_inference_op,
       /*rewrite_grad_op=*/rewrite_grad_op);
   TF_CHECK_OK(computation->Accept(&visitor));
-  return visitor.changed_;
+  return visitor.changed();
 }
 
 Status BatchNormExpanderVisitor::HandleBatchNormTraining(
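The hunks above fold BatchNormExpanderVisitor onto the shared DfsHloRewriteVisitor base class, which already provides DefaultAction, ReplaceInstruction, ReplaceWithNewInstruction, and the changed() flag, so the hand-rolled copies are deleted. A minimal sketch of a rewrite pass written against that base class (the folding pass itself is hypothetical; only the base-class helpers are taken from this diff):

  // Folds negate(negate(x)) -> x. ReplaceInstruction comes from
  // DfsHloRewriteVisitor and records that the computation changed.
  class DoubleNegationFolder : public DfsHloRewriteVisitor {
   public:
    Status HandleElementwiseUnary(HloInstruction* hlo) override {
      if (hlo->opcode() == HloOpcode::kNegate &&
          hlo->operand(0)->opcode() == HloOpcode::kNegate) {
        return ReplaceInstruction(hlo,
                                  hlo->mutable_operand(0)->mutable_operand(0));
      }
      return Status::OK();
    }
  };

  // Driven the same way as BatchNormExpanderVisitor::Run above:
  //   DoubleNegationFolder folder;
  //   TF_CHECK_OK(computation->Accept(&folder));
  //   bool changed = folder.changed();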
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
 
 #include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
 
 #include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -656,7 +656,7 @@ Status BufferAssignment::ComputeSummaryStats() {
   for (const auto& computation : module_->computations()) {
     if (!computation->IsFusionComputation()) {
       const HloInstructionSequence* sequence =
-          liveness_->hlo_ordering().SequentialOrder(*computation);
+          hlo_ordering().SequentialOrder(*computation);
       if (sequence == nullptr) {
         schedule_complete = false;
       } else {
@@ -833,7 +833,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
   const HloValue& assigned_buffer =
       *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
   for (const HloValue* new_value : hlo_buffer.values()) {
-    if (assignment->liveness().hlo_ordering().MayInterfere(
+    if (assignment->hlo_ordering().MayInterfere(
            assigned_buffer, *new_value, assignment->dataflow_analysis())) {
       VLOG(4) << "Can't assign: assignee " << assigned_buffer
               << " may interfere with " << new_value;
@@ -917,7 +917,7 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
 
   for (const HloValue* instruction_value : instruction_buffer.values()) {
     for (const HloValue* operand_value : operand_buffer.values()) {
-      if (assignment->liveness().hlo_ordering().MayInterfere(
+      if (assignment->hlo_ordering().MayInterfere(
              *instruction_value, *operand_value,
              assignment->dataflow_analysis())) {
         interfere = true;
@@ -1047,8 +1047,7 @@ Status BufferAssigner::AssignSingleHloBuffer(
   for (const HloValue* hlo_value : hlo_buffer->values()) {
     HloComputation* computation = hlo_value->instruction()->parent();
     const bool has_sequential_order =
-        assignment->liveness().hlo_ordering().SequentialOrder(*computation) !=
-        nullptr;
+        assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
     all_computations_have_sequential_order &= has_sequential_order;
   }
 
@@ -1125,10 +1124,9 @@ Status BufferAssigner::AssignBuffersForComputations(
     }
   }
 
-  const BufferLiveness& liveness = assignment->liveness();
   for (const HloComputation* computation : computations) {
     const bool has_sequential_order =
-        liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
+        assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
     if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
       // Every sequential computation must get an entry in the
       // buffers_to_assign_sequentially map, even if we end up with an empty
@@ -1197,7 +1195,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
   // Run the sequence of instructions through the heap simulator. The
   // heuristic that seems to give the best results is lazy-best-fit, with all
   // runs of alloc / free calls sorted in decreasing size order.
-  const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+  const HloOrdering& hlo_ordering = assignment->hlo_ordering();
 
   // Returns a heap algorithm that chooses the best result from several
   // algorithms.
@@ -1392,9 +1390,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
     BufferValue::SizeFunction buffer_size,
     LogicalBuffer::AlignmentFunction color_alignment,
     HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
-                      BufferLiveness::Run(module, std::move(hlo_ordering)));
-
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                       HloAliasAnalysis::Run(module, can_share_buffer));
 
@@ -1408,11 +1403,11 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   // Can't use absl::make_unique because BufferAssignment constructor is
   // private.
   std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
-      module, std::move(liveness), std::move(buffer_size),
+      module, std::move(hlo_ordering), std::move(buffer_size),
       std::move(color_alignment), std::move(alias_analysis)));
 
-  TF_RETURN_IF_ERROR(colorer_(&assignment->alias_analysis(),
-                              assignment->liveness().hlo_ordering()));
+  TF_RETURN_IF_ERROR(
+      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
   VLOG(3) << "After coloring:";
   XLA_VLOG_LINES(3,
                  assignment->alias_analysis().dataflow_analysis().ToString());
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -25,7 +25,6 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/types/span.h"
-#include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/heap_simulator.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@@ -440,11 +439,6 @@ class BufferAssignment {
   bool HaveDisjointSlices(const HloInstruction* hlo_a,
                           const HloInstruction* hlo_b) const;
 
-  // Returns the underlying points-to analysis used for this assignment.
-  const TuplePointsToAnalysis& points_to_analysis() const {
-    return liveness_->points_to_analysis();
-  }
-
   const HloDataflowAnalysis& dataflow_analysis() const {
     return alias_analysis_->dataflow_analysis();
   }
@@ -452,7 +446,7 @@ class BufferAssignment {
   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
 
   // Returns the BufferLiveness object used to construct this assignment.
-  const BufferLiveness& liveness() const { return *liveness_; }
+  const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
 
   string ToString() const;
   BufferAssignmentProto ToProto() const;
@@ -483,12 +477,12 @@ class BufferAssignment {
   friend class BufferAssigner;
 
   BufferAssignment(const HloModule* module,
-                   std::unique_ptr<BufferLiveness> liveness,
+                   std::unique_ptr<HloOrdering> hlo_ordering,
                    BufferValue::SizeFunction buffer_size,
                    LogicalBuffer::AlignmentFunction color_alignment,
                    std::unique_ptr<HloAliasAnalysis> alias_analysis)
       : module_(module),
-        liveness_(std::move(liveness)),
+        hlo_ordering_(std::move(hlo_ordering)),
         buffer_size_(std::move(buffer_size)),
         color_alignment_(std::move(color_alignment)),
         alias_analysis_(std::move(alias_analysis)) {}
@@ -540,7 +534,8 @@ class BufferAssignment {
       allocation_index_for_value_;
 
   const HloModule* module_;
-  const std::unique_ptr<BufferLiveness> liveness_;
+
+  const std::unique_ptr<HloOrdering> hlo_ordering_;
 
   // Function which returns the buffer size for a given logical buffer (shape).
   BufferValue::SizeFunction buffer_size_;
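The buffer-assignment changes are mechanical: the HloOrdering that used to be reached through the owned BufferLiveness is now owned by BufferAssignment directly, so every call site drops one hop. In miniature (the variable names are illustrative; the expressions restate the substitutions above):

  // Before: reach the ordering through the (now deleted) liveness object.
  const HloOrdering& via_liveness = assignment->liveness().hlo_ordering();

  // After: BufferAssignment owns the HloOrdering directly.
  const HloOrdering& direct = assignment->hlo_ordering();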
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ /dev/null
@@ -1,179 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Defines the data returned by the XLA buffer assignment packages.
-
-#include "tensorflow/compiler/xla/service/buffer_liveness.h"
-
-#include <utility>
-#include <vector>
-
-#include "absl/strings/str_format.h"
-#include "absl/strings/str_join.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace xla {
-
-/* static */
-StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
-    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering) {
-  std::unique_ptr<BufferLiveness> liveness(
-      new BufferLiveness(module, std::move(hlo_ordering)));
-  TF_RETURN_IF_ERROR(liveness->Analyze());
-  return std::move(liveness);
-}
-
-Status BufferLiveness::Analyze() {
-  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
-  for (auto* computation : module_->computations()) {
-    if (computation->IsFusionComputation()) {
-      continue;
-    }
-    // Gather all instructions whose buffers might alias other instructions into
-    // the set aliased_buffers_. This includes those contained as a tuple
-    // element in other instruction's output.
-    for (const auto& instruction : computation->instructions()) {
-      for (const LogicalBuffer* aliased_buffer :
-           points_to_analysis_->GetPointsToSet(instruction)
-               .CreateFlattenedSet()) {
-        if (aliased_buffer->instruction() != instruction) {
-          aliased_buffers_.insert(aliased_buffer);
-        }
-      }
-    }
-
-    if (computation == module_->entry_computation()) {
-      const HloInstruction* root = computation->root_instruction();
-      maybe_live_out_buffers_ =
-          points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet();
-    }
-  }
-
-  XLA_VLOG_LINES(3, ToString());
-  return Status::OK();
-}
-
-string BufferLiveness::ToString() const {
-  std::vector<string> pieces;
-  pieces.push_back(
-      absl::StrFormat("BufferLiveness(module=%s):", module_->name()));
-  pieces.push_back("HloOrdering:");
-  pieces.push_back(hlo_ordering_->ToString());
-  pieces.push_back("Aliased buffers:");
-  for (const LogicalBuffer* buffer : aliased_buffers_) {
-    pieces.push_back(absl::StrFormat("  %s", buffer->ToString()));
-  }
-  pieces.push_back("Live out buffers:");
-  for (const LogicalBuffer* buffer : maybe_live_out_buffers_) {
-    pieces.push_back(absl::StrFormat("  %s", buffer->ToString()));
-  }
-  return absl::StrJoin(pieces, "\n");
-}
-
-bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
-                                                const LogicalBuffer& b) const {
-  TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a));
-  TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b));
-
-  if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
-    return false;
-  }
-
-  for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
-    // Every user of 'a' must be a predecessor of 'b' or 'b' itself.
-    for (auto user : alias.instruction()->users()) {
-      if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(),
-                                                       alias.index(), user)) {
-        continue;
-      }
-      if (user != b.instruction() &&
-          !hlo_ordering_->ExecutesBefore(user, b.instruction())) {
-        return false;
-      }
-    }
-
-    // If the root instruction aliases the buffer 'a', the live range of 'a' is
-    // until the end of the computation and can never be strictly before another
-    // buffer nested in the same computation. This is needed to prevent the root
-    // instruction's buffers from being reused by later instructions even when
-    // the root is not the last instruction in the schedule.
-    if (alias.instruction()->parent()->root_instruction() ==
-            alias.instruction() &&
-        hlo_ordering_->call_graph().InstructionIsNestedIn(
-            b.instruction(), alias.instruction()->parent())) {
-      return false;
-    }
-  }
-
-  // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction'
-  // and 'b.instruction' emit the same shape/layout, and 'b.instruction' meets
-  // the qualifications specified in CanShareOperandBufferWithUser.
-  for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
-    if (b.instruction()->IsUserOf(alias.instruction()) &&
-        !points_to_analysis().CanShareOperandBufferWithUser(
-            alias.instruction(), alias.index(), b.instruction(), b.index())) {
-      return false;
-    }
-  }
-  return true;
-}
-
-namespace {
-bool IsEntryParameter(const HloInstruction* instruction) {
-  const HloComputation* computation = instruction->parent();
-  return instruction->opcode() == HloOpcode::kParameter &&
-         computation == computation->parent()->entry_computation();
-}
-}  // namespace
-
-bool BufferLiveness::MayInterfere(const LogicalBuffer& a,
-                                  const LogicalBuffer& b) const {
-  // Parameters live at the entry of the computation, thus always interfere with
-  // all other instructions inside the computation executing before them in the
-  // ordering.
-  const HloInstruction* a_instruction = a.instruction();
-  const HloInstruction* b_instruction = b.instruction();
-  if (a_instruction->opcode() == HloOpcode::kParameter &&
-      hlo_ordering_->call_graph().InstructionIsNestedIn(
-          b_instruction, a_instruction->parent()) &&
-      hlo_ordering_->ExecutesBefore(b_instruction, a_instruction)) {
-    return true;
-  }
-  if (b_instruction->opcode() == HloOpcode::kParameter &&
-      hlo_ordering_->call_graph().InstructionIsNestedIn(
-          a_instruction, b_instruction->parent()) &&
-      hlo_ordering_->ExecutesBefore(a_instruction, b_instruction)) {
-    return true;
-  }
-  // Buffers without disjoint liveness may interfere.
-  return !live_range_strictly_before(a, b) && !live_range_strictly_before(b, a);
-}
-
-bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const {
-  // Verify that a buffer is actually defined at the given instruction/index
-  // (eg, its not an alias of another buffer such as occurs with a bitcast).
-  TF_CHECK_OK(points_to_analysis_->VerifyBuffer(buffer));
-  return maybe_live_out_buffers_.count(&buffer);
-}
-
-}  // namespace xla
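At its core, the deleted analysis is a symmetric disjointness test: two buffers may interfere unless one live range is provably strictly before the other, with entry parameters special-cased as live from the start of their computation. For callers it was driven in two steps, construct from a module plus an ordering, then query pairwise interference or liveness-out, as the deleted tests further below exercise. In miniature (buffer_a and buffer_b stand for LogicalBuffers obtained from the points-to analysis; only APIs visible in this diff are used):

  auto liveness =
      BufferLiveness::Run(module.get(),
                          absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();
  bool interfere = liveness->MayInterfere(buffer_a, buffer_b);
  bool live_out = liveness->MaybeLiveOut(buffer_a);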
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ /dev/null
@@ -1,114 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
-
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "absl/container/flat_hash_set.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace xla {
-
-// Class which computes liveness of the output buffers of HLOs and their
-// interference.
-class BufferLiveness {
- public:
-  using Colorer = std::function<Status(const BufferLiveness& buffer_liveness)>;
-
-  // Constructs a buffer liveness object for the given module assuming the given
-  // HLO instruction ordering.
-  static StatusOr<std::unique_ptr<BufferLiveness>> Run(
-      const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering);
-
-  // Returns true if the live range of the buffer containing the output of 'a'
-  // may overlap with the live range of the buffer of 'b'. If instruction 'a'
-  // interferes with instruction 'b' then they cannot share the same buffer.
-  bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;
-
-  // Returns true if the buffer for the given instruction may be live out of the
-  // module. That is, the instruction's buffer may be included in the output of
-  // the entry computation.
-  bool MaybeLiveOut(const LogicalBuffer& buffer) const;
-
-  // Returns the complete set of buffers that may be live out of the module.
-  const PointsToSet::BufferSet& maybe_live_out_buffers() const {
-    return maybe_live_out_buffers_;
-  }
-
-  // Returns the underlying points-to analysis used for this liveness analysis.
-  const TuplePointsToAnalysis& points_to_analysis() const {
-    return *points_to_analysis_;
-  }
-
-  // Returns the underlying hlo ordering used for this liveness analysis.
-  const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
-
-  const HloModule& module() const { return *module_; }
-
-  string ToString() const;
-
-  static Colorer DefaultColorer() {
-    return [](const BufferLiveness& buffer_liveness) {
-      for (LogicalBuffer::Id id = 0;
-           id < buffer_liveness.points_to_analysis().num_logical_buffers();
-           id++) {
-        auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
-        buffer.set_color(LogicalBuffer::Color(0));
-      }
-      return Status::OK();
-    };
-  }
-
- private:
-  explicit BufferLiveness(const HloModule* module,
-                          std::unique_ptr<HloOrdering> hlo_ordering)
-      : module_(module), hlo_ordering_(std::move(hlo_ordering)) {}
-
-  // Perform buffer liveness analysis. This method must be called prior to
-  // MayInterfere or MaybeLiveOut.
-  Status Analyze();
-
-  // Returns true if the live range of the buffer of 'a' is strictly before the
-  // live range of the buffer of 'b' (they do not overlap).
-  bool live_range_strictly_before(const LogicalBuffer& a,
-                                  const LogicalBuffer& b) const;
-
-  const HloModule* module_;
-  std::unique_ptr<HloOrdering> hlo_ordering_;
-
-  // Set of LogicalBuffers which are aliased in the output of other
-  // instructions. For example, a LogicalBuffer which is inserted into a tuple
-  // is considered to be aliased and will be in this set.
-  absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
-
-  // LogicalBuffers that may be live out of the entry computation.
-  PointsToSet::BufferSet maybe_live_out_buffers_;
-
-  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
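Note that the Colorer hook changes shape along with this removal: the deleted header's DefaultColorer received the whole BufferLiveness object, while the surviving call in BufferAssigner::CreateAssignment above passes the alias analysis and the ordering separately. Side by side (the first line is from the deleted header; the second restates the call shown earlier in this diff):

  // Old hook signature (deleted with this header):
  //   using Colorer =
  //       std::function<Status(const BufferLiveness& buffer_liveness)>;

  // New call shape, as seen in buffer_assignment.cc above:
  //   colorer_(&assignment->alias_analysis(), assignment->hlo_ordering());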
@ -1,934 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
||||||
|
|
||||||
namespace xla {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
class BufferLivenessTest : public HloTestBase {
|
|
||||||
protected:
|
|
||||||
// Returns the LogicalBuffer defined at the given instruction and
|
|
||||||
// index. CHECKs if no buffer is defined at that point.
|
|
||||||
const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
|
|
||||||
const HloInstruction* instruction,
|
|
||||||
const ShapeIndex& index) {
|
|
||||||
const auto& pointed_to = liveness.points_to_analysis()
|
|
||||||
.GetPointsToSet(instruction)
|
|
||||||
.element(index);
|
|
||||||
CHECK_EQ(1, pointed_to.size());
|
|
||||||
CHECK_EQ(instruction, pointed_to[0]->instruction());
|
|
||||||
CHECK(index == pointed_to[0]->index());
|
|
||||||
return *pointed_to[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if the top-level buffers for instructions 'a' and 'b' may
|
|
||||||
// interfere. Precondition: 'a' and 'b' are array-shaped.
|
|
||||||
bool InstructionsMayInterfere(const BufferLiveness& liveness,
|
|
||||||
HloInstruction* a, HloInstruction* b) {
|
|
||||||
EXPECT_FALSE(a->shape().IsTuple());
|
|
||||||
EXPECT_FALSE(b->shape().IsTuple());
|
|
||||||
return liveness.MayInterfere(
|
|
||||||
GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
|
|
||||||
GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
|
|
||||||
// may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
|
|
||||||
// tuple element sub-shapes.
|
|
||||||
bool TupleElementsMayInterfere(const BufferLiveness& liveness,
|
|
||||||
HloInstruction* a, HloInstruction* b,
|
|
||||||
const ShapeIndex& index) {
|
|
||||||
// Check that top-level shapes are tuple and tuple element shapes are equal.
|
|
||||||
EXPECT_TRUE(a->shape().IsTuple());
|
|
||||||
EXPECT_TRUE(b->shape().IsTuple());
|
|
||||||
EXPECT_TRUE(
|
|
||||||
ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
|
|
||||||
ShapeUtil::GetSubshape(b->shape(), index)));
|
|
||||||
// Lookup PointsTo set for instructions 'a' and 'b'.
|
|
||||||
auto& points_to_analysis = liveness.points_to_analysis();
|
|
||||||
const auto& points_to_a =
|
|
||||||
points_to_analysis.GetPointsToSet(a).element(index);
|
|
||||||
const auto& points_to_b =
|
|
||||||
points_to_analysis.GetPointsToSet(b).element(index);
|
|
||||||
// Make sure PointsTo sets for 'a' and 'b' are unambiguous.
|
|
||||||
EXPECT_EQ(1, points_to_a.size());
|
|
||||||
EXPECT_EQ(points_to_a.size(), points_to_b.size());
|
|
||||||
// Check interference.
|
|
||||||
return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if the top-level buffers for the given instruction maybe
|
|
||||||
// liveout of the entry computation.
|
|
||||||
// Precondition: instruction is array-shaped.
|
|
||||||
bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
|
|
||||||
HloInstruction* instruction) {
|
|
||||||
return liveness.MaybeLiveOut(
|
|
||||||
GetBuffer(liveness, instruction, /*index=*/{}));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<HloComputation> BuildDummyComputation() {
|
|
||||||
auto builder = HloComputation::Builder(TestName() + "_dummy");
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
return builder.Build();
|
|
||||||
}
|
|
||||||
|
|
||||||
const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, ElementwiseChain) {
|
|
||||||
// A simple chain of elementwise operations. No buffers should interfere.
|
|
||||||
//
|
|
||||||
// param --> negate -> exp -> log
|
|
||||||
//
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
|
|
||||||
auto log = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
|
|
||||||
|
|
||||||
// No buffers should interfere.
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp));
|
|
||||||
|
|
||||||
// Buffers should interfere with itself.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp));
|
|
||||||
|
|
||||||
// Only log is live out.
|
|
||||||
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
|
|
||||||
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate));
|
|
||||||
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
|
|
||||||
// Two entry params, which interfere with each other.
|
|
||||||
//
|
|
||||||
// param0 --> negate ---------------\
|
|
||||||
// param1 --> exp --> add
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param0 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, vec_, "param0"));
|
|
||||||
auto param1 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(1, vec_, "param1"));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
|
|
||||||
auto add = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
HloComputation* entry = module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
HloSchedule schedule(module.get());
|
|
||||||
schedule.set_sequence(entry, {param0, negate, param1, exp, add});
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(module.get(),
|
|
||||||
absl::make_unique<SequentialHloOrdering>(schedule))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
// Entry parameters interfere as if they are defined simultaneously at
|
|
||||||
// the very beginning.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add));
|
|
||||||
|
|
||||||
// Negate and exp still interfere.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
|
|
||||||
|
|
||||||
// But {negate, add} and {exp, add} don't interfere.
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, EmbeddedComputationParameters) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule EmbeddedComputationParameters, is_scheduled=true
|
|
||||||
|
|
||||||
%EmbeddedComputationParameters_embedded (embedded_param0: f32[42], embedded_param1: f32[42]) -> (f32[42], f32[42]) {
|
|
||||||
%embedded_param0 = f32[42]{0} parameter(0)
|
|
||||||
%log = f32[42]{0} log(f32[42]{0} %embedded_param0)
|
|
||||||
%add = f32[42]{0} add(f32[42]{0} %log, f32[42]{0} %log)
|
|
||||||
%embedded_param1 = f32[42]{0} parameter(1)
|
|
||||||
ROOT %tuple = (f32[42]{0}, f32[42]{0}) tuple(f32[42]{0} %add, f32[42]{0} %embedded_param1)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY %EmbeddedComputationParameters (param0: f32[42], param1: f32[42]) -> (f32[42], f32[42]) {
|
|
||||||
%param0 = f32[42]{0} parameter(0)
|
|
||||||
%param1 = f32[42]{0} parameter(1)
|
|
||||||
ROOT %call = (f32[42]{0}, f32[42]{0}) call(f32[42]{0} %param0, f32[42]{0} %param1), to_apply=%EmbeddedComputationParameters_embedded
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
HloModuleConfig hlo_config;
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(),
|
|
||||||
absl::make_unique<SequentialHloOrdering>(module->schedule()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
auto embedded_log = FindInstruction(module.get(), "log");
|
|
||||||
auto embedded_param0 = FindInstruction(module.get(), "embedded_param0");
|
|
||||||
auto embedded_param1 = FindInstruction(module.get(), "embedded_param1");
|
|
||||||
auto param0 = FindInstruction(module.get(), "param0");
|
|
||||||
auto param1 = FindInstruction(module.get(), "param1");
|
|
||||||
|
|
||||||
// Parameters should interfere with other instructions inside the computation.
|
|
||||||
EXPECT_TRUE(
|
|
||||||
InstructionsMayInterfere(*liveness, embedded_log, embedded_param1));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param0));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param1));
|
|
||||||
EXPECT_TRUE(
|
|
||||||
InstructionsMayInterfere(*liveness, embedded_param0, embedded_param1));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, InterferenceWithOuterRoot) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule InterferenceWithOuterRoot, is_scheduled=true
|
|
||||||
|
|
||||||
Emmbedded (embedded_param: f32[42]) -> f32[42] {
|
|
||||||
embedded_param = f32[42]{0} parameter(0)
|
|
||||||
multiply = f32[42]{0} multiply(embedded_param, embedded_param)
|
|
||||||
ROOT log = f32[42]{0} log(multiply)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY InterferenceWithOuterRoot {
|
|
||||||
param = f32[4096,4096]{1,0} parameter(0)
|
|
||||||
ROOT add = f32[4096,4096]{1,0} add(param, param)
|
|
||||||
call = f32[42]{0} call(param), to_apply=Emmbedded
|
|
||||||
}
|
|
||||||
|
|
||||||
)";
|
|
||||||
HloModuleConfig hlo_config;
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnUnverifiedModule(hlo_string, hlo_config));
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(),
|
|
||||||
absl::make_unique<SequentialHloOrdering>(module->schedule()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
auto multiply = FindInstruction(module.get(), "multiply");
|
|
||||||
auto add = FindInstruction(module.get(), "add");
|
|
||||||
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, multiply, add));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, NonElementwiseOperand) {
|
|
||||||
// A chain of operations with two elementwise and one non-elementwise. The
|
|
||||||
// elementwise op should not interfere with its operand, while the
|
|
||||||
// non-elementwise op should interfere. Entry params always interfere.
|
|
||||||
//
|
|
||||||
// param --> exp -> negate -> reverse
|
|
||||||
//
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
|
|
||||||
auto reverse =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse));
|
|
||||||
|
|
||||||
// Negate is elementwise, so doesn't interfere with its operand.
|
|
||||||
// Reverse is non-elementwise, so does interfere with its operand.
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, OverlappedBuffers) {
|
|
||||||
// Verify simultaneously live buffers interfere (exp and negate).
|
|
||||||
//
|
|
||||||
// param --> negate -> add
|
|
||||||
// \---> exp -----/
|
|
||||||
//
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
|
|
||||||
auto add = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
|
|
||||||
|
|
||||||
// Negate and exp interfere with each other, but not with add.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
|
|
||||||
// Identical to the test OverlappedBuffer but using a sequential ordering of
|
|
||||||
// HLO instructions.
|
|
||||||
//
|
|
||||||
// param --> negate -> add
|
|
||||||
// \---> exp -----/
|
|
||||||
//
|
|
||||||
// Sequential order:
|
|
||||||
// param, negate, exp, add
|
|
||||||
//
|
|
||||||
// Liveness is identical to the DependencyHloOrdering.
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
|
|
||||||
auto add = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
HloSchedule schedule(module.get());
|
|
||||||
schedule.set_sequence(computation, {param, negate, exp, add});
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(module.get(),
|
|
||||||
absl::make_unique<SequentialHloOrdering>(schedule))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
|
|
||||||
|
|
||||||
// Negate and exp interfere with each other, but not with add.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add));
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
|
|
||||||
// Tests that when the root instruction is not the last instruction in the
|
|
||||||
// schedule, the live range of its buffers interfere with the buffers of the
|
|
||||||
// later instructions.
|
|
||||||
//
|
|
||||||
// Two sets of independent instructions are executed in the computation.
|
|
||||||
// param --> add (root)
|
|
||||||
// recv --> recv-done --> send --> send-done
|
|
||||||
//
|
|
||||||
// Sequential order:
|
|
||||||
// param, add (root), recv, recv-done, send, send-done
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto add = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
|
|
||||||
auto token = builder.AddInstruction(HloInstruction::CreateToken());
|
|
||||||
auto recv = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
|
|
||||||
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
|
|
||||||
auto send = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
|
|
||||||
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
auto computation = module->AddEntryComputation(builder.Build(add));
|
|
||||||
|
|
||||||
HloSchedule schedule(module.get());
|
|
||||||
schedule.set_sequence(computation,
|
|
||||||
{param, add, token, recv, recv_done, send, send_done});
|
|
||||||
TF_ASSERT_OK(schedule.Verify());
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(module.get(),
|
|
||||||
absl::make_unique<SequentialHloOrdering>(schedule))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
|
|
||||||
// Check the root instruction (add) buffer interferes with the recv buffer.
|
|
||||||
EXPECT_TRUE(
|
|
||||||
liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}),
|
|
||||||
GetBuffer(*liveness, recv, /*index=*/{0})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, TupleLiveOut) {
|
|
||||||
// Verify MaybeLiveOut with nested tuples. Result of computation looks like:
|
|
||||||
//
|
|
||||||
// Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))})
|
|
||||||
//
|
|
||||||
// All values should be live out except Param.
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto negate = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
|
|
||||||
auto inner_tuple =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateTuple({negate}));
|
|
||||||
auto exp = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
|
|
||||||
auto outer_tuple =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
// All buffers should be live out except the param
|
|
||||||
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple));
|
|
||||||
}
|
|
||||||
|
|
||||||
// bitcast liveout.
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, EmbeddedComputation) {
|
|
||||||
// Test MaybeLiveOut and MayInterfere for embedded computation.
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
|
|
||||||
auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
|
|
||||||
auto embedded_param = embedded_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, vec_, "embedded_param"));
|
|
||||||
auto embedded_log = embedded_builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
|
|
||||||
|
|
||||||
auto embedded_computation =
|
|
||||||
module->AddEmbeddedComputation(embedded_builder.Build());
|
|
||||||
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param =
|
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
|
|
||||||
auto call = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateCall(vec_, {param}, embedded_computation));
|
|
||||||
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
// Buffers in different computations should always interfere.
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call));
|
|
||||||
EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param));
|
|
||||||
EXPECT_FALSE(
|
|
||||||
InstructionsMayInterfere(*liveness, embedded_param, embedded_log));
|
|
||||||
|
|
||||||
// The only buffers for which MaybeLiveOut == true are those live out
|
|
||||||
// of the entry computation. Buffers live out of embedded computations should
|
|
||||||
// return false for this method.
|
|
||||||
EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log));
|
|
||||||
EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
|
|
||||||
// Verify non top-level elements of a nested tuple constant are properly
|
|
||||||
// marked as liveout. Computation:
|
|
||||||
//
|
|
||||||
// GetTupleElement(0, TupleConstant({{0, 1}, {3}})
|
|
||||||
//
|
|
||||||
// Only the array buffers containing 0 and 1 are liveout of the
|
|
||||||
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
|
|
||||||
// the buffers containing {3} and 3 are dead.
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
|
|
||||||
LiteralUtil::CreateR0<int64>(1)};
|
|
||||||
auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
|
|
||||||
Literal element1 = LiteralUtil::CreateR0<int64>(3);
|
|
||||||
auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
|
|
||||||
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
|
|
||||||
LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
|
|
||||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
|
||||||
inner_tuple0.shape(), tuple_constant, 0));
|
|
||||||
|
|
||||||
auto module = CreateNewVerifiedModule();
|
|
||||||
module->AddEntryComputation(builder.Build());
|
|
||||||
|
|
||||||
auto liveness =
|
|
||||||
BufferLiveness::Run(
|
|
||||||
module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
|
|
||||||
.ConsumeValueOrDie();
|
|
||||||
|
|
||||||
// Only the element buffers of the tuple constant which are pointed to by
|
|
||||||
// the GetTupleElement instruction should be liveout.
|
|
||||||
EXPECT_FALSE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{})));
|
|
||||||
EXPECT_TRUE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
|
|
||||||
EXPECT_TRUE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
|
|
||||||
EXPECT_TRUE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));
|
|
||||||
EXPECT_FALSE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
|
|
||||||
EXPECT_FALSE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
|
|
||||||
EXPECT_FALSE(liveness->MaybeLiveOut(
|
|
||||||
GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
|
|
||||||
}

TEST_F(BufferLivenessTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  // Create param0 Tuple.
  auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
      "param0"));
  // Create independent computations for each tuple element.

  // Tuple element0 computation:
  // Add(GetTupleElement(tuple_param0, 0), const0)
  auto tuple_element0_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
  auto tuple_element0 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          tuple_element0_shape, tuple_param0, 0));
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));

  // Tuple element1 computation:
  // Add(GetTupleElement(tuple_param0, 1), const1)
  auto tuple_element1_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
  auto tuple_element1 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          tuple_element1_shape, tuple_param0, 1));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));

  // Create output tuple.
  auto tuple_root =
      builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));

  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(BuildDummyComputation());
  module->AddEmbeddedComputation(builder.Build());

  auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();

  // We compare tuple element pairs that are input/output to the computation:
  // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
  // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')

  // Tuple output element 'add0' does not depend on input 'tuple_element1'.
  // Tuple output element 'add1' does not depend on input 'tuple_element0'.

  // Neither element pair interferes, because there is no other dependency on
  // either pair's tuple input element, and so liveness can compute that all
  // users of the input tuple element execute before the associated output
  // tuple element.
  EXPECT_FALSE(
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
  EXPECT_FALSE(
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
}
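
// Added for illustration (not part of the original test file): a plausible
// sketch of the fixture's TupleElementsMayInterfere helper used throughout
// these tests, written against the hypothetical GetBufferSketch above.
bool TupleElementsMayInterfereSketch(const BufferLiveness& liveness,
                                     HloInstruction* a, HloInstruction* b,
                                     const ShapeIndex& index) {
  // The tuple elements at 'index' interfere iff the live ranges of their
  // logical buffers may overlap.
  return liveness.MayInterfere(GetBufferSketch(liveness, a, index),
                               GetBufferSketch(liveness, b, index));
}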

TEST_F(BufferLivenessTest, DependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  // Create param0 Tuple.
  auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
      "param0"));
  // Create dependent computations for each tuple element.

  // Tuple element0 computation:
  // Add(GetTupleElement(tuple_param0, 0), const0)
  auto tuple_element0_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
  auto tuple_element0 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          tuple_element0_shape, tuple_param0, 0));
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));

  // Tuple element1 computation:
  // Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
  auto tuple_element1_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
  auto tuple_element1 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          tuple_element1_shape, tuple_param0, 1));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1));

  // Create output tuple.
  auto tuple_root =
      builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(BuildDummyComputation());
  module->AddEmbeddedComputation(builder.Build());

  auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();

  // We compare tuple element pairs that are input/output to the computation:
  // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
  // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')

  // The first pair's output 'add0' has no dependency on the second pair's
  // input 'tuple_element1'.

  // The second pair's output 'add1' has a dependency on the first pair's
  // input 'tuple_element0'.

  // The first tuple element pair does interfere, because liveness cannot
  // compute that all references to 'tuple_element0' are executed before 'add0'
  // (because of the dependency of 'add1' on 'tuple_element0').
  EXPECT_TRUE(
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));

  // The second tuple element pair does not interfere, because there is no
  // other dependency on 'tuple_element1', and so liveness can compute that
  // all users execute before 'add1'.
  EXPECT_FALSE(
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
}
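
// Note (added commentary, not in the original file): the tests above use
// DependencyHloOrdering, so "executes before" is derived purely from data
// dependencies. That is why the extra use of 'tuple_element0' by 'add1' in
// DependentTupleElements is enough for liveness to conservatively report
// interference at index {0}.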

class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
 protected:
  // Builds a computation (see test case computation graphs below).
  std::unique_ptr<VerifiedHloModule> BuildModule(
      const bool update_uses_tuple_element1, const bool fuse_gte0) {
    auto builder = HloComputation::Builder(TestName());
    // Create param0 Tuple.
    Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    Shape update_shape = ShapeUtil::MakeShape(F32, {3});
    auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));

    auto gte0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));

    auto gte1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));

    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
    HloInstruction* slice = nullptr;
    if (update_uses_tuple_element1) {
      // Create a slice instruction as an additional user of 'gte1'.
      slice = builder.AddInstruction(
          HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
      update = builder.AddInstruction(HloInstruction::CreateBinary(
          update_shape, HloOpcode::kAdd, update, slice));
    }
    // Create a DynamicUpdateSlice instruction of tuple element 1 with
    // 'update'.
    auto starts = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
    auto dynamic_update_slice =
        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            data_shape, gte1, update, {starts}));
    // Create output tuple.
    builder.AddInstruction(
        HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    // Build module and get reference to entry computation.
    auto module = CreateNewVerifiedModule();
    module->AddEntryComputation(builder.Build());
    auto* computation = module->entry_computation();
    // Create fusion instruction based on number of tuple element 1 users.
    if (update_uses_tuple_element1) {
      computation->CreateFusionInstruction(
          {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
          HloInstruction::FusionKind::kLoop);
    } else {
      computation->CreateFusionInstruction(
          {dynamic_update_slice, starts, update, gte1},
          HloInstruction::FusionKind::kLoop);
    }
    // Create fusion instruction for tuple element 0 (if requested).
    if (fuse_gte0) {
      computation->CreateFusionInstruction({gte0},
                                           HloInstruction::FusionKind::kLoop);
    }
    return module;
  }

  // Returns whether buffer interference is detected between tuple-shaped
  // parameter and root instructions at tuple element 1.
  bool Run(const bool update_uses_tuple_element1,
           const bool fuse_gte0 = false) {
    auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
    // Run BufferLiveness on 'module'.
    auto liveness = BufferLiveness::Run(
                        module.get(),
                        absl::make_unique<DependencyHloOrdering>(module.get()))
                        .ConsumeValueOrDie();
    // Return whether or not buffer interference is detected between
    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
    auto tuple_param0 = FindInstruction(module.get(), "param0");
    auto tuple_root = module->entry_computation()->root_instruction();
    return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
  }

  // Like Run() above, but answers the same interference query through
  // HloDataflowAnalysis and the HLO ordering directly.
  bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
                                  const bool fuse_gte0 = false) {
    auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
    // Run HloDataflowAnalysis on 'module'.
    auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
    auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
    // Return whether or not buffer interference is detected between
    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
    auto tuple_param0 = FindInstruction(module.get(), "param0");
    auto tuple_root = module->entry_computation()->root_instruction();
    return hlo_ordering->MayInterfere(
        dataflow->GetUniqueValueAt(tuple_param0, {1}),
        dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
  }
};
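
// Note (added commentary, not in the original file): Run() and
// RunWithHloDataflowAnalysis() ask the same interference question through two
// APIs: BufferLiveness::MayInterfere over LogicalBuffers from points-to
// analysis, and HloOrdering::MayInterfere over HloValues from
// HloDataflowAnalysis. The tests below expect both answers to agree.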

// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias the
// fusion instruction) do not overlap in the following computation:
//
//                 Param0
//                /      \
//           GTE(0)      Fusion -----------> FusionParam
//              |           |                    |
//              |           |                 GTE(1) Const Const
//              |           |                    \     |    /
//              |           |                  DynamicUpdateSlice  // fused root
//               \         /
//                 Tuple  // computation root
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
  EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
  EXPECT_FALSE(
      RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
}

// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
// 'fusion1') do not overlap in the presence of another fusion instruction
// (which is a user of 'param0' at a different tuple index).
// BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
// (because Fusion0 only uses Param0 at index {0}).
//
//                                Param0
//                               /      \
//     FusionParam <----- Fusion0        Fusion1 ------> FusionParam
//          |                |              |                 |
//        GTE(0)             |              |              GTE(1) Const Const
//                           |              |                 \     |    /
//                            \            /              DynamicUpdateSlice
//                             \          /
//                              Tuple
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
  EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
  EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
                                          /*fuse_gte0=*/true));
}

// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias the
// fusion instruction) do overlap because GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Slice at operand 0.
//
//                 Param0
//                /      \       Const
//               /        \       /
//           GTE(0)      Fusion -----------> FusionParam  FusionParam
//              |           |                    |             |
//              |           |                 GTE(1)          /
//              |           |                    |  \        /
//              |           |                    |  Slice   /
//              |           |                    |    \    /
//              |           |                    |     Add     Const
//              |           |                     \     |        |
//              |           |                   DynamicUpdateSlice  // fused root
//               \         /
//                 Tuple  // computation root
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
  EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
  EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
}
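
// Note (added commentary, not in the original file): fusing the extra Slice
// user of GTE(1) into the same fusion does not hide it; the use remains
// visible to both analyses, so interference at index {1} is still detected,
// as asserted above.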

class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
 protected:
  // Builds and runs a computation (see test case computation graphs below).
  // Runs BufferLiveness on this computation.
  // Returns whether buffer interference is detected between tuple-shaped
  // parameter and root instructions at tuple element 1.
  bool Run(const bool tuple_element1_has_two_uses) {
    auto builder = HloComputation::Builder(TestName());
    // Create param0 Tuple.
    Shape data_shape = ShapeUtil::MakeShape(F32, {8});
    Shape update_shape = ShapeUtil::MakeShape(F32, {3});
    auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));

    auto gte0 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));

    auto gte1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));

    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));

    if (tuple_element1_has_two_uses) {
      // Add 'gte0' and 'gte1' to create another user of 'gte1'.
      gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
          data_shape, HloOpcode::kAdd, gte0, gte1));
    }
    // Create a DynamicUpdateSlice instruction of tuple element 1 with
    // 'update'.
    auto starts = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
    auto dynamic_update_slice =
        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            data_shape, gte1, update, {starts}));
    // Create output tuple.
    auto tuple_root = builder.AddInstruction(
        HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
    // Build module and get reference to entry computation.
    auto module = CreateNewVerifiedModule();
    module->AddEntryComputation(BuildDummyComputation());
    module->AddEmbeddedComputation(builder.Build());
    // Run BufferLiveness on 'module'.
    auto liveness = BufferLiveness::Run(
                        module.get(),
                        absl::make_unique<DependencyHloOrdering>(module.get()))
                        .ConsumeValueOrDie();
    // Return whether or not buffer interference is detected between
    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
    return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
  }
};
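
// Note (added commentary, not in the original file): at the liveness level,
// tuple element 1 of the parameter can share a buffer with the
// DynamicUpdateSlice output only when GTE(1) has no other live user. The two
// tests below cover the single-use case (no interference) and the two-use
// case (interference).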

// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
// the following computation (because DynamicUpdateSlice (at operand 0) is the
// unique user):
//
//              Parameter0
//               |      |
//           GTE(0)   GTE(1)  Const Const
//              |        \      |    /
//              |      DynamicUpdateSlice
//               \         /
//                  Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
  EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
}

// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
// GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Add at operand 1.
//
//              Parameter0
//               |      |
//           GTE(0)   GTE(1)
//              |    /   |
//              |   /    |
//             Add       |     Const Const
//              |        |       |     |
//              |      DynamicUpdateSlice
//               \          /
//                  Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
  EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
}

}  // namespace

}  // namespace xla