Merge branch 'master' into sriniva2/threadpool_build

commit 3692d790f8

.bazelrc (2 changed lines)
@@ -168,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1
 build:dbg --config=opt -c dbg
 # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
+# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
+build:dbg --copt -DDEBUG_BUILD
 build:tensorrt --action_env TF_NEED_TENSORRT=1
RELEASE.md (144 changed lines)

@@ -1,3 +1,147 @@

# Release 2.2.0

TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).

Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated.

## Major Features and Improvements

* Replaced the scalar type for string tensors from `std::string` with `tensorflow::tstring`, which is now ABI stable.
* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisories are provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines.
* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md).
* `tf.distribute`:
  * Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training (a minimal sketch follows this `tf.distribute` list).
  * Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
  * Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see the [NCCL developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information.
  * Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage.
  * Experimental support of [all-reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward-pass computation.
  * Deprecated the `experimental_run_v2` method for distribution strategies and renamed it `run`, as it is no longer experimental.
  * Add `CompositeTensor` support for `DistributedIterator`s. This should help prevent unnecessary function retracing and memory leaks.
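A minimal sketch of the sync batch-norm layer named above (illustrative only; the model architecture and hyperparameters are assumptions, not from this release):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
      # Normalization statistics are synchronized across all replicas
      # every step, instead of being computed per replica.
      tf.keras.layers.experimental.SyncBatchNormalization(),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer="sgd", loss="mse")
```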
* `tf.keras`:
  * `Model.fit` major improvements:
    * You can now use custom training logic with `Model.fit` by overriding `Model.train_step` (a sketch follows this `tf.keras` list).
    * Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc.).
    * See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. The same applies to validation and inference via `Model.test_step` and `Model.predict_step`.
  * SavedModel now uses its own `Model._saved_model_inputs_spec` attribute instead of relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclassed Models. This attribute is set in eager, `tf.function`, and graph modes, which removes the need for users to manually call `Model._set_inputs` when using custom training loops (CTLs).
  * Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator. This used to happen implicitly in `Model._standardize_user_data`. Long term, a solution where the `DataAdapter` doesn't need to call the Model is probably preferable.
  * The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers).
  * Update the Keras batch normalization layer to use the running mean and average computation in `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in eager mode.
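A rough sketch of the `Model.train_step` override described above (the model and loss are hypothetical; `compiled_loss`, `compiled_metrics`, and `optimizer` are the hooks `Model.fit` configures via `compile`):

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
  def train_step(self, data):
    x, y = data  # assumes a (features, labels) dataset; adapt as needed
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(y, y_pred)
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    self.compiled_metrics.update_state(y, y_pred)
    # Model.fit reports whatever metric dict this returns.
    return {m.name: m.result() for m in self.metrics}
```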
* `tf.lite`:
  * Enable the TFLite experimental new converter by default.
* XLA:
  * XLA now builds and works on Windows. All prebuilt packages come with XLA available.
  * XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction) with "compile or throw exception" semantics on CPU and GPU (a sketch follows this list).
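A minimal sketch of the explicit-compilation semantics just described (the function body is an assumption; `experimental_compile` is the TF 2.2 flag name):

```python
import tensorflow as tf

# The function is compiled with XLA, or an exception is raised if it
# cannot be compiled ("compile or throw exception" semantics).
@tf.function(experimental_compile=True)
def dense_layer(x, w, b):
  return tf.nn.relu(tf.matmul(x, w) + b)
```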
## Breaking Changes

* `tf.keras`:
  * In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer.
  * The Huber loss function has been updated to be consistent with other Keras losses. It now computes the mean over the last axis of per-sample losses before applying the reduction function.
* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func`, and `tf.numpy_function`.
* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release.
* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`.
* Keras compile/fit behavior for functional and subclassed models has been unified. Model properties such as `metrics` and `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses. The `loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed.

## Known Caveats

* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
## Bug Fixes and Other Changes

* `tf.data`:
  * Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
  * `tf.constant` always creates CPU tensors irrespective of the current device context.
  * Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
  * For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
  * `pfor`/`vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
  * Set as much partial shape as we can infer statically within the gradient impl of the gather op.
  * The gradient of `tf.while_loop` emits a `StatelessWhile` op if the `cond` and body functions are stateless. This allows multiple gradient while ops to run in parallel under a distribution strategy.
  * Speed up `GradientTape` in eager mode by auto-generating the list of op inputs/outputs which are unused and hence not cached for gradient functions.
  * Support `back_prop=False` in `while_v2`, but mark it as deprecated.
  * Improve the error message when attempting to use `None` in data-dependent control flow.
  * Add `RaggedTensor.numpy()`.
  * Update `RaggedTensor.__getitem__` to preserve uniform dimensions and allow indexing into uniform dimensions (a short sketch of both `RaggedTensor` changes follows this list).
  * Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
  * Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
  * Allow `batch_dims == rank(indices)` in `tf.gather`.
  * Add support for bfloat16 in `tf.print`.
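A short sketch of the two `RaggedTensor` additions above (the example values are made up):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])
print(rt.numpy())  # new .numpy(): an array of per-row numpy arrays
print(rt[:, :2])   # __getitem__ slicing: up to two elements of each row
```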
* `tf.distribute`:
  * Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
  * Added an `experimental_aggregate_gradients` argument to `tf.keras.optimizers.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing of aggregated gradients in a custom training loop (a sketch follows this list).
  * Allow `pathlib.Path` paths for loading models via the Keras API.
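A minimal sketch of the new `apply_gradients` argument (the toy variable and loss are assumptions):

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
v = tf.Variable(2.0)
with tf.GradientTape() as tape:
  loss = v * v
grads = tape.gradient(loss, [v])
# Gradients here are treated as already aggregated (e.g. by a custom
# all-reduce), so the optimizer's own aggregation step is skipped.
opt.apply_gradients(zip(grads, [v]), experimental_aggregate_gradients=False)
```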
* `tf.function`/AutoGraph:
  * AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update`, and `Strategy.extended.update_non_slot`.
  * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additional info (a sketch follows this list).
  * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
  * Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
  * Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
  * Fix the execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
  * You can now iterate over `RaggedTensor`s using a for loop inside `tf.function`.
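A sketch of the shape-invariant support mentioned above (the loop body is an assumption; only `set_loop_options(shape_invariants=...)` is taken from the note):

```python
import tensorflow as tf

@tf.function
def grow(x):
  for _ in tf.range(3):
    # Declare that x's first dimension may change across iterations.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(x, tf.TensorShape([None]))])
    x = tf.concat([x, x], axis=0)
  return x

print(grow(tf.constant([1.0])))  # shape grows from [1] to [8]
```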
* `tf.lite`:
  * Migrated the `tf.lite` C inference API out of experimental into lite/c.
  * Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10.
  * TFLite Android AARs now include the C headers and APIs that are required to use TFLite from native code.
  * Refactor the delegate and delegate kernel sources to allow usage in the linter.
  * Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU fallback is disabled.
  * TFLite now supports the `tf.math.reciprocal1` op by lowering it to the `tf.div` op.
  * TFLite's unpack op now supports boolean tensor inputs.
  * Microcontroller and embedded code moved from experimental to the main TensorFlow Lite folder.
  * Check for large TFLite tensors.
  * Fix a GPU delegate crash with C++17.
  * Add 5D support to TFLite `strided_slice`.
  * Fix an error in delegation of `DEPTH_TO_SPACE` to `NNAPI` that caused the op not to be accelerated.
  * Fix a segmentation fault when running a model with LSTM nodes using the `NNAPI` delegate.
  * Fix an `NNAPI` delegate failure when an operand for a Maximum/Minimum operation is a scalar.
  * Fix an `NNAPI` delegate failure when the axis input for a reduce operation is a scalar.
  * Expose an option to limit the number of partitions that will be delegated to `NNAPI`.
  * If a target accelerator is specified, use its feature level to determine operations to delegate instead of the SDK version.
* `tf.random`:
  * Various random number generation improvements:
    * Add a fast path for default `random_uniform`.
    * `random_seed` documentation improvement.
    * `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
  * Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, and `tf.random.stateless_poisson` (a sketch follows this list).
  * `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
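A minimal sketch of one of the new stateless samplers (the shape, seed, and parameters are made up):

```python
import tensorflow as tf

# Stateless samplers are pure functions of (shape, seed, params):
# the same seed always produces the same draws.
samples = tf.random.stateless_binomial(
    shape=[5], seed=[12, 34], counts=10., probs=0.5)
print(samples)  # identical output on every call with this seed
```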
* Math and Linear Algebra:
  * Add `tf.linalg.LinearOperatorTridiag`.
  * Add `tf.linalg.LinearOperatorBlockLowerTriangular`.
  * Add broadcasting support to `tf.linalg.triangular_solve` ([#26204](https://github.com/tensorflow/tensorflow/issues/26204)) and `tf.math.invert_permutation`.
  * Add the `tf.math.sobol_sample` op.
  * Add `tf.math.xlog1py` (a sketch follows this list).
  * Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
  * Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
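A short sketch of `tf.math.xlog1py` (the input values are made up):

```python
import tensorflow as tf

x = tf.constant([0.0, 2.0])
y = tf.constant([3.0, 0.5])
# xlog1py(x, y) = x * log(1 + y), defined to be 0 when x == 0,
# which matters for entropy-style sums.
print(tf.math.xlog1py(x, y))  # [0.0, 2 * log(1.5)]
```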
* TPU Enhancements:
  * Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
  * Support configuring the TPU software version from the Cloud TPU client.
  * Allowed the TPU embedding weight decay factor to be multiplied by the learning rate.
* XLA Support:
  * Add a standalone XLA AOT runtime target + relevant .cc sources to the pip package.
  * Add a check for memory alignment to `MemoryAllocation::MemoryAllocation()` on 32-bit ARM. This ensures a deterministic early exit instead of a hard-to-debug bus error later.
  * `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
  * Enable `Igamma` and `Igammac` for XLA.
* Deterministic Op Functionality:
  * The XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias addition in Keras layers) to include when XLA JIT compilation is enabled.
  * Fix a problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINISTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
  * Add source and destination names to the `_send` traceme to allow easier debugging.
  * Add a traceme event to `fastpathexecute`.
* Other:
  * Fix an issue with `AUC.reset_states` for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852).
  * Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
  * Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.

## Thanks to our Contributors

This release contains contributions from many people at Google, as well as:
372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi

# Release 2.0.1

## Bug Fixes and Other Changes
configure.py (36 changed lines)

@@ -144,7 +144,7 @@ def write_to_bazelrc(line):


 def write_action_env_to_bazelrc(var_name, var):
-  write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
+  write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))


 def run_shell(cmd, allow_non_zero=False, stderr=None):
@@ -205,7 +205,7 @@ def setup_python(environ_cp):
   # Get PYTHON_BIN_PATH, default is the current running python.
   default_python_bin_path = sys.executable
   ask_python_bin_path = ('Please specify the location of python. [Default is '
-                         '%s]: ') % default_python_bin_path
+                         '{}]: ').format(default_python_bin_path)
   while True:
     python_bin_path = get_from_env_or_user_or_default(environ_cp,
                                                       'PYTHON_BIN_PATH',
@@ -215,9 +215,10 @@ def setup_python(environ_cp):
     if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
       break
     elif not os.path.exists(python_bin_path):
-      print('Invalid python path: %s cannot be found.' % python_bin_path)
+      print('Invalid python path: {} cannot be found.'.format(python_bin_path))
     else:
-      print('%s is not executable. Is it the python binary?' % python_bin_path)
+      print('{} is not executable. Is it the python binary?'.format(
+          python_bin_path))
     environ_cp['PYTHON_BIN_PATH'] = ''

   # Convert python path to Windows style before checking lib and version
@@ -236,7 +237,7 @@ def setup_python(environ_cp):
     default_python_lib_path = python_lib_paths[0]
     python_lib_path = get_input(
         'Please input the desired Python library path to use. '
-        'Default is [%s]\n' % python_lib_paths[0])
+        'Default is [{}]\n'.format(python_lib_paths[0]))
     if not python_lib_path:
       python_lib_path = default_python_lib_path
   environ_cp['PYTHON_LIB_PATH'] = python_lib_path
@@ -252,7 +253,7 @@ def setup_python(environ_cp):
   # Set-up env variables used by python_configure.bzl
   write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
   write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
-  write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
+  write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
   environ_cp['PYTHON_BIN_PATH'] = python_bin_path

   # If choosen python_lib_path is from a path specified in the PYTHONPATH
@@ -266,7 +267,7 @@ def setup_python(environ_cp):
   with open(
       os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
       'w') as f:
-    f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
+    f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))


 def reset_tf_configure_bazelrc():
@@ -320,11 +321,12 @@ def get_var(environ_cp,
     Raise the error to avoid infinitely looping.
   """
   if not question:
-    question = 'Do you wish to build TensorFlow with %s support?' % query_item
+    question = 'Do you wish to build TensorFlow with {} support?'.format(
+        query_item)
   if not yes_reply:
-    yes_reply = '%s support will be enabled for TensorFlow.' % query_item
+    yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
   if not no_reply:
-    no_reply = 'No %s' % yes_reply
+    no_reply = 'No {}'.format(yes_reply)

   yes_reply += '\n'
   no_reply += '\n'

@@ -368,7 +370,7 @@ def get_var(environ_cp,
       print(no_reply)
       var = False
     else:
-      print('Invalid selection: %s' % user_input_origin)
+      print('Invalid selection: {}'.format(user_input_origin))
   return var
@@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version):
   if which('bazel') is None:
     print('Cannot find bazel. Please install bazel.')
     sys.exit(0)
-  curr_version = run_shell(
-      ['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
-
-  for line in curr_version.split('\n'):
-    if 'Build label: ' in line:
-      curr_version = line.split('Build label: ')[1]
-      break
+  stderr = open(os.devnull, 'wb')
+  curr_version = run_shell(['bazel', '--version'],
+                           allow_non_zero=True,
+                           stderr=stderr)
+  if curr_version.startswith('bazel '):
+    curr_version = curr_version.split('bazel ')[1]

   min_version_int = convert_version_to_int(min_version)
   curr_version_int = convert_version_to_int(curr_version)
@@ -517,6 +517,7 @@ package_group(
         "//perftools/accelerators/xprof/api/...",
         "//third_party/py/autograph/...",
         "//third_party/swift/tensorflow/x10/...",
+        "//third_party/swift/tensorflow_apis/...",
         "//tensorflow/...",
         "//tensorflow_estimator/python/estimator/...",
         "//tensorflow_models/official/...",
@@ -529,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
 # TODO(b/154762408) Remove this package group once it's no longer needed.
 package_group(name = "composite_tensor_whitelist")

+# Packages that use private types symbols, until they are exported.
+# TODO(b/154650521) Remove.
+package_group(
+    name = "types_whitelist",
+    packages = ["//learning/deepmind/tensorflow/replicator/..."],
+)
+
 filegroup(
     name = "intel_binary_blob",
     data = if_mkl_ml(
@@ -16,7 +16,6 @@ load(
     "//tensorflow/core/platform:build_config_root.bzl",
     "tf_cuda_tests_tags",
 )
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

 package(
     licenses = ["notice"],  # Apache 2.0

@@ -609,7 +608,6 @@ filegroup(
         ],
         exclude = [
             "c_api_experimental.cc",
-            "*c_api_tfrt*",
             "*test*",
             "*dlpack*",
         ],
@@ -38,7 +38,7 @@ limitations under the License.
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_tensor_internal.h"
 #ifdef PLATFORM_GOOGLE
-#include "tensorflow/c/eager/c_api_tfrt.h"
+#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
 #endif
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
       context->GetDevicePlacementPolicy());
 }

-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
   if (!status->status.ok()) return nullptr;
@@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
 // placed in memory of different devices or remote address spaces.
 typedef struct TFE_TensorHandle TFE_TensorHandle;

-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
                                                             TF_Status* status);
 // Indicates that the caller will not be using `h` any more.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
@@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
   return GetServerDef("localhost", num_tasks);
 }

+void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
+  tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->at(task_index) =
+      tensorflow::strings::StrCat("localhost:", port);
+}
+
 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
                                     const std::vector<float>& expected_values) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
   TF_DeleteStatus(status);
 }

+// Read the value of variable `var` and save it into `out_value`.
+void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
+                  TFE_TensorHandle** out_value) {
+  TF_Status* status = TF_NewStatus();
+  TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  int num_retvals = 1;
+  TFE_Execute(op, out_value, &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(op);
+  TF_DeleteStatus(status);
+}
+
 void TestRemoteExecuteChangeServerDef(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
|
|||||||
TestRemoteExecuteUpdateServerDef(true);
|
TestRemoteExecuteUpdateServerDef(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server->Start().ok());
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
|
|
||||||
|
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||||
|
EXPECT_NE(var_handle0, nullptr);
|
||||||
|
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||||
|
EXPECT_NE(var_handle1, nullptr);
|
||||||
|
|
||||||
|
TFE_TensorHandle* value_handle = nullptr;
|
||||||
|
ReadVariable(ctx, var_handle1, &value_handle);
|
||||||
|
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||||
|
TFE_DeleteTensorHandle(value_handle);
|
||||||
|
|
||||||
|
// Start a new worker to replace task:1
|
||||||
|
ReplaceTaskInServerDef(&server_def, 1);
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||||
|
worker_server.release();
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server->Start().ok());
|
||||||
|
|
||||||
|
// Update server def to replace the remote device with the device info on the
|
||||||
|
// new worker (different incarnation ID).
|
||||||
|
server_def.set_task_index(0);
|
||||||
|
string serialized_update = server_def.SerializeAsString();
|
||||||
|
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||||
|
serialized_update.size(), status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
// The device of var_handle0 is local device which is the same before and
|
||||||
|
// after cluster update. Remove resource with valid device should succeed.
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, var_handle0, status);
|
||||||
|
TFE_OpSetDevice(op, dev0_name, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
int num_retvals = 0;
|
||||||
|
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
|
||||||
|
// The device of var_handle1 is remote device, which was replaced during
|
||||||
|
// cluster update. Removing resource with invalid device should fail
|
||||||
|
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||||
|
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, var_handle1, status);
|
||||||
|
TFE_OpSetDevice(op, dev1_name, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
num_retvals = 0;
|
||||||
|
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||||
|
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(var_handle0);
|
||||||
|
TFE_DeleteTensorHandle(var_handle1);
|
||||||
|
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||||
|
worker_server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
|
||||||
|
TestRemoteExecuteUpdateServerDefResourceAccess(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
|
||||||
|
TestRemoteExecuteUpdateServerDefResourceAccess(true);
|
||||||
|
}
|
||||||
|
|
||||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||||
// when updating cluster with non-exsitent worker
|
// when updating cluster with non-exsitent worker
|
||||||
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
|||||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||||
job_def->mutable_tasks()->insert(
|
job_def->mutable_tasks()->insert(
|
||||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||||
|
server_def.set_task_index(0);
|
||||||
string serialized_update = server_def.SerializeAsString();
|
string serialized_update = server_def.SerializeAsString();
|
||||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||||
serialized_update.size(), status);
|
serialized_update.size(), status);
|
||||||
|
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/eager/tfe_op_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
@@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
   return tensorflow::wrap(
       tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
 }
+
+TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
+                                               TFE_TensorHandle** handles,
+                                               int* num_handles,
+                                               TF_Status* status) {
+  std::vector<tensorflow::TensorHandle*> tensor_handles;
+  tensor_handles.reserve(*num_handles);
+  for (int i = 0; i < *num_handles; ++i) {
+    tensor_handles.push_back(
+        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
+  }
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  tensorflow::TensorHandle* handle = nullptr;
+  status->status = tensorflow::TensorHandle::CreatePackedHandle(
+      std::move(tensor_handles), context, &handle);
+  return tensorflow::wrap(handle);
+}
+
+void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                       TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetAllowSoftPlacement(enable);
+}
+
+void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
+                                      TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  context->SetLogDevicePlacement(enable);
+}
@@ -541,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
 TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
     TFE_Context* ctx, TF_Tensor* t, TF_Status* status);

+// Create a packed TensorHandle with the given list of TensorHandles.
+// If `handles` are on the same device, assign the same device to the packed
+// handle; if `handles` are on different deivces, assign a CompositeDevice to
+// it.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
+    TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
+    TF_Status* status);
+
+// Configure soft device placement policy for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
+                                                      unsigned char enable,
+                                                      TF_Status* status);
+
+// Configure device placement policy logging for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
+                                                     unsigned char enable,
+                                                     TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
@@ -351,6 +351,192 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
                                   /*heavy_load_on_streaming_rpc=*/true);
 }

+// Add the values of three variables on three different tasks.
+string AddVariablesFunction() {
+  tensorflow::FunctionDef def;
+  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+      "    signature {"
+      "      name: 'AddVariablesFunction'"
+      "      input_arg {"
+      "        name: 'var'"
+      "        type: DT_RESOURCE"
+      "      }"
+      "      output_arg {"
+      "        name: 'sum'"
+      "        type: DT_FLOAT"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read0'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read1'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:1/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'read2'"
+      "      op: 'ReadVariableOp'"
+      "      input: 'var'"
+      "      device: '/job:localhost/replica:0/task:2/device:CPU:0'"
+      "      attr {"
+      "        key: 'dtype'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'add1'"
+      "      op: 'Add'"
+      "      input: 'read0:value:0'"
+      "      input: 'read1:value:0'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    node_def {"
+      "      name: 'add2'"
+      "      op: 'Add'"
+      "      input: 'add1:z:0'"
+      "      input: 'read2:value:0'"
+      "      attr {"
+      "        key: 'T'"
+      "        value {"
+      "          type: DT_FLOAT"
+      "        }"
+      "      }"
+      "    }"
+      "    ret {"
+      "      key: 'sum'"
+      "      value: 'add2:z:0'"
+      "    }",
+      &def));
+  return def.SerializeAsString();
+}
+
+TEST(CAPI, TestFunctionWithPackedInput) {
+  tensorflow::ServerDef server_def = GetServerDef(3);
+
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server1)
+                  .ok());
+  ASSERT_TRUE(worker_server1->Start().ok());
+
+  server_def.set_task_index(2);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server2)
+                  .ok());
+  ASSERT_TRUE(worker_server2->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
+
+  // Create one variable per task.
+  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
+  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
+  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
+
+  // Pack 3 variable handles into one TFE_TensorHandle.
+  int num_replicas = 3;
+  std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
+  TFE_TensorHandle* packed_handle =
+      TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
+  EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
+  EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
+
+  const string composite_device_name =
+      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";
+  EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
+            composite_device_name);
+  EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
+            composite_device_name);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  // Register and run a function which returns the sum of 3 variables.
+  const string function_def = AddVariablesFunction();
+  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
+                            status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_OpAddInput(func, packed_handle, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
+  TFE_Execute(func, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+  TFE_DeleteOp(func);
+  TFE_DeleteTensorHandle(packed_handle);
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  float sum = 0;
+  EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
+  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(sum, 6.0);
+
+  TFE_DeleteTensorHandle(h0);
+  TFE_DeleteTensorHandle(h1);
+  TFE_DeleteTensorHandle(h2);
+
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status);
+  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  TFE_DeleteExecutor(executor);
+  TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
+  TFE_DeleteContext(ctx);
+
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server1.release();
+  worker_server2.release();
+}
+
 void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
   tensorflow::ServerDef server_def = GetServerDef(2);
@@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
 }
 BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);

-TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
-                                 TF_Status* status) {
-  // Create the variable handle.
-  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
-  TFE_OpSetAttrString(op, "container", "", 0);
-  TFE_OpSetAttrString(op, "shared_name", "", 0);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_TensorHandle* var_handle = nullptr;
-  int num_retvals = 1;
-  TFE_Execute(op, &var_handle, &num_retvals, status);
-  TFE_DeleteOp(op);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  CHECK_EQ(1, num_retvals);
-
-  // Assign 'value' to it.
-  op = TFE_NewOp(ctx, "AssignVariableOp", status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-  TFE_OpAddInput(op, var_handle, status);
-
-  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
-  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
-      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
-  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
-
-  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
-      value_handle(TFE_NewTensorHandle(t.get(), status),
-                   TFE_DeleteTensorHandle);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-
-  TFE_OpAddInput(op, value_handle.get(), status);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-
-  num_retvals = 0;
-  TFE_Execute(op, nullptr, &num_retvals, status);
-  TFE_DeleteOp(op);
-  if (TF_GetCode(status) != TF_OK) return nullptr;
-  CHECK_EQ(0, num_retvals);
-
-  return var_handle;
-}
-
 TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
@@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);

-  TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
+  TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);

@@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);

-  TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
+  TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);

@@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) {
     CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
     h = nullptr;
+    TFE_OpAddInput(op, var_handle, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(op);
@@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
   return th;
 }

+TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
+                               const tensorflow::string& device_name) {
+  TF_Status* status = TF_NewStatus();
+  // Create the variable handle.
+  TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpSetAttrShape(op, "shape", {}, 0, status);
+  TFE_OpSetAttrString(op, "container", "", 0);
+  TFE_OpSetAttrString(op, "shared_name", "", 0);
+  if (!device_name.empty()) {
+    TFE_OpSetDevice(op, device_name.c_str(), status);
+  }
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_TensorHandle* var_handle = nullptr;
+  int num_retvals = 1;
+  TFE_Execute(op, &var_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(1, num_retvals);
+
+  // Assign 'value' to it.
+  op = TFE_NewOp(ctx, "AssignVariableOp", status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+  TFE_OpAddInput(op, var_handle, status);
+
+  // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
+      TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
+  memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
+
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      value_handle(TFE_NewTensorHandle(t.get(), status),
+                   TFE_DeleteTensorHandle);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_OpAddInput(op, value_handle.get(), status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  num_retvals = 0;
+  TFE_Execute(op, nullptr, &num_retvals, status);
+  TFE_DeleteOp(op);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  CHECK_EQ(0, num_retvals);
+
+  TF_DeleteStatus(status);
+
+  return var_handle;
+}
+
 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   TF_Status* status = TF_NewStatus();
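For orientation, a minimal sketch of driving the new helper end to end. This is a hypothetical caller, not part of the change; it reuses only eager C API calls that already appear in this diff:

// Hypothetical snippet: create a variable initialized to 3.0f via
// TestVariable, then read it back with ReadVariableOp.
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = TestVariable(ctx, 3.0f);
TFE_Op* read = TFE_NewOp(ctx, "ReadVariableOp", status);
TFE_OpSetAttrType(read, "dtype", TF_FLOAT);
TFE_OpAddInput(read, var_handle, status);
TFE_TensorHandle* value = nullptr;
int num_retvals = 1;
TFE_Execute(read, &value, &num_retvals, status);
// ... check `status`, use `value`, then delete the handles, op, context,
// and status in the usual order.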
@@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
 // Return a tensor handle containing a 3x2 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);

+// Return a variable handle referring to a variable with the given initial value
+// on the given device.
+TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
+                               const tensorflow::string& device_name = "");
+
 // Return an add op adding `a` and `b`.
 TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
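Because `device_name` defaults to the empty string, callers can either let placement fall through or pin the variable explicitly. Hypothetical call sites (the device string is the conventional TensorFlow format, shown for illustration):

// Hypothetical call sites for the helper declared above.
TFE_TensorHandle* v0 = TestVariable(ctx, 2.0f);  // default placement
TFE_TensorHandle* v1 =
    TestVariable(ctx, 2.0f, "/job:localhost/replica:0/task:0/device:CPU:0");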
@@ -29,7 +29,7 @@ using tensorflow::string;
 namespace tensorflow {
 namespace {

-TEST(UnifedCAPI, TestBasicEager) {
+TEST(UnifiedCAPI, TestBasicEager) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -81,7 +81,7 @@ TEST(UnifedCAPI, TestBasicEager) {
   TF_DeleteExecutionContext(ctx);
 }

-TEST(UnifedCAPI, TestBasicGraph) {
+TEST(UnifiedCAPI, TestBasicGraph) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -131,6 +131,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
   string fn_name = "double";
   TF_AbstractFunction* func = TF_ExecutionContextToFunction(
       graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TF_DeleteAbstractTensor(placeholder_t);
   TF_DeleteAbstractTensor(output_t);

@@ -184,7 +185,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
   TF_DeleteExecutionContext(eager_execution_ctx);
 }

-TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -200,7 +201,7 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   TF_DeleteExecutionContext(ctx);
 }

-TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -221,7 +222,7 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -242,7 +243,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
+TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
   // Build an Eager context.
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -288,7 +289,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
+TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -59,6 +59,20 @@ class AbstractContextInterface {
   virtual AbstractTensorInterface* CreateTensor(
       DataType dtype, absl::Span<const int64> dim_sizes) = 0;

+  typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
+
+  // Create a tensor instance from the given data buffer and description.
+  // `memory_releaser` will be called on destruction, and it's responsible for
+  // cleaning up the underlying buffer. `convert_string` indicates whether it
+  // has to handle tstring conversion. Expected to be removed once tstring
+  // migration is done.
+  virtual AbstractTensorInterface* CreateTensor(DataType dtype,
+                                                const int64_t* dims,
+                                                int num_dims, void* data,
+                                                size_t len, bool convert_string,
+                                                MemoryReleaser memory_releaser,
+                                                void* memory_releaser_arg) = 0;
+
   // Create a handle to wrap and manage a Tensor
   virtual AbstractTensorHandleInterface* CreateLocalHandle(
       AbstractTensorInterface* t) = 0;
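To make the new overload concrete, here is a hedged sketch of a releaser paired with a call through the interface. `ctx` stands in for some concrete AbstractContextInterface implementation and `MakeTensor` is a hypothetical wrapper, neither part of the diff:

// Hypothetical illustration of the CreateTensor overload added above
// (assumes <cstdlib> for malloc/free).
static void FreeReleaser(void* data, size_t len, void* arg) {
  free(data);  // `len` and `arg` are unused in this sketch.
}

void MakeTensor(AbstractContextInterface* ctx) {
  float* buf = static_cast<float*>(malloc(4 * sizeof(float)));
  const int64_t dims[] = {4};
  AbstractTensorInterface* t = ctx->CreateTensor(
      DT_FLOAT, dims, /*num_dims=*/1, buf, /*len=*/4 * sizeof(float),
      /*convert_string=*/false, &FreeReleaser, /*memory_releaser_arg=*/nullptr);
  // `t` now owns `buf` through FreeReleaser.
}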
@@ -27,6 +27,7 @@ cc_library(
     name = "parallel_device",
     srcs = [":sources"],
     hdrs = [":headers"],
+    visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
@@ -43,6 +44,7 @@ tf_cc_test(
     srcs = ["parallel_device_test.cc"],
     deps = [
         ":parallel_device",
+        ":parallel_device_ops",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
         "//tensorflow/c/eager:c_api",
@@ -52,3 +54,19 @@ tf_cc_test(
         "//tensorflow/core:test_main",
     ],
 )
+
+# Note: ParallelDevice-specific ops are experimental and not currently linked in
+# to TensorFlow by default, just used in a few tests.
+filegroup(
+    name = "parallel_device_ops_srcs",
+    srcs = ["parallel_device_ops.cc"],
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+)
+
+cc_library(
+    name = "parallel_device_ops",
+    srcs = [":parallel_device_ops_srcs"],
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/core:framework"],
+    alwayslink = 1,
+)
@@ -92,6 +92,10 @@ class ParallelDevice {
                                                  TFE_TensorHandle* tensor,
                                                  TF_Status* status) const;

+  // A parallel tensor with scalar integers numbering component devices.
+  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
+                                            TF_Status* status) const;
+
   // Takes a description of a single operation being executed on the
   // ParallelDevice, and in turn runs one operation per component device with
   // its corresponding inputs from the input ParallelTensors (or
@@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
                                            status);
 }

+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+    TFE_Context* context, TF_Status* status) const {
+  // TODO(allenl): We could cache DeviceIDs (keyed by context).
+  std::vector<TensorHandlePtr> components;
+  components.reserve(underlying_devices_.size());
+  for (int device_index = 0; device_index < underlying_devices_.size();
+       ++device_index) {
+    int64_t* device_id = new int64_t;
+    *device_id = device_index;
+    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+        TF_NewTensor(
+            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+            sizeof(int64_t),
+            [](void* data, size_t, void* arg) {
+              delete reinterpret_cast<int64_t*>(data);
+            },
+            nullptr),
+        TF_DeleteTensor);
+    // TODO(allenl): Here and when executing regular operations, we could hold
+    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+    // device names repeatedly.
+    OpPtr const_op(TFE_NewOp(context, "Const", status));
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+                    status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
+    TFE_TensorHandle* device_handle;
+    int num_outputs = 1;
+    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    components.emplace_back(device_handle);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+  return ParallelTensor::FromTensorHandles(*this, std::move(components),
+                                           status);
+}
+
 absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
     const char* operation_name, const TFE_OpAttrs* attributes,
@@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     }
     result.emplace(std::move(outputs));
     return result;
+  } else if (operation_name == std::string("DeviceID")) {
+    std::vector<MaybeParallelTensorOwned> result_content;
+    result_content.reserve(1);
+    result_content.push_back(DeviceIDs(context, status));
+    if (TF_GetCode(status) != TF_OK) return result;
+    result.emplace(std::move(result_content));
+    return result;
   }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
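The per-device components above lean on TF_NewTensor's deallocator argument to tie each heap-allocated int64 to its tensor's lifetime. In isolation the pattern looks like this (a standalone sketch, not part of the diff):

// Standalone sketch: a scalar TF_INT64 tensor that owns its buffer and
// frees it when TF_DeleteTensor runs the deallocator.
int64_t* device_id = new int64_t(0);
TF_Tensor* scalar = TF_NewTensor(
    TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, sizeof(int64_t),
    [](void* data, size_t, void*) { delete static_cast<int64_t*>(data); },
    /*deallocator_arg=*/nullptr);
TF_DeleteTensor(scalar);  // invokes the deallocator above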
26  tensorflow/c/eager/parallel_device/parallel_device_ops.cc  Normal file
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+// TODO(allenl): Figure out if we need this op, and if so whether we should move
+// it to core TF. Right now the eager C API does some checking of op
+// registrations before calling into custom devices, but we may be able to avoid
+// that.
+REGISTER_OP("DeviceID")
+    .Output("device_id: int64")
+    .SetIsStateful()
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
@@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
 }

 // Assert that `handle` is equal to `expected_value`.
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
       TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  ASSERT_EQ(expected_value,
-            *static_cast<float*>(TF_TensorData(value_zero.get())));
+  EXPECT_EQ(expected_value,
+            *static_cast<value_type*>(TF_TensorData(value_zero.get())));
 }

 template <std::size_t num_devices>
|
|||||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
AssertScalarFloatEq(components[0].get(), 20.);
|
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||||
AssertScalarFloatEq(components[1].get(), 20.);
|
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||||
|
|
||||||
std::string first_device =
|
std::string first_device =
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||||
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
|||||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
AssertScalarFloatEq(components[0].get(), 23.);
|
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||||
AssertScalarFloatEq(components[1].get(), 18.);
|
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||||
|
|
||||||
std::string first_device =
|
std::string first_device =
|
||||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||||
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
|||||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||||
ASSERT_EQ(underlying_devices[1], second_device);
|
ASSERT_EQ(underlying_devices[1], second_device);
|
||||||
}
|
}
|
||||||
|
// Compute the device ID twice and verify the result
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||||
|
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
TFE_TensorHandle* result_handle;
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
std::array<TensorHandlePtr, 2> components;
|
||||||
|
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||||
|
TFE_DeleteTensorHandle(result_handle);
|
||||||
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
|
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||||
|
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||||
|
std::string first_device =
|
||||||
|
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||||
|
ASSERT_EQ(underlying_devices[0], first_device);
|
||||||
|
std::string second_device =
|
||||||
|
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||||
|
ASSERT_EQ(underlying_devices[1], second_device);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||||
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
|||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
// The value of the original tensor is replicated on each device.
|
// The value of the original tensor is replicated on each device.
|
||||||
AssertScalarFloatEq(components[0].get(), 3.);
|
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||||
AssertScalarFloatEq(components[1].get(), 3.);
|
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||||
|
|
||||||
// Verify that the mirrors are placed on the component devices.
|
// Verify that the mirrors are placed on the component devices.
|
||||||
std::string first_device =
|
std::string first_device =
|
||||||
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
|||||||
&second_components, status.get());
|
&second_components, status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
|
|
||||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||||
|
|
||||||
// Verify that the mirrors are placed on the component devices.
|
// Verify that the mirrors are placed on the component devices.
|
||||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||||
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
|||||||
std::array<TensorHandlePtr, 2> first_components;
|
std::array<TensorHandlePtr, 2> first_components;
|
||||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||||
&first_components, status.get());
|
&first_components, status.get());
|
||||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||||
|
|
||||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||||
status.get());
|
status.get());
|
||||||
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
|||||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||||
status.get());
|
status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||||
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
|||||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||||
status.get());
|
status.get());
|
||||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||||
|
|
||||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||||
result_components[0].get(), status.get());
|
result_components[0].get(), status.get());
|
||||||
|
@@ -178,7 +178,7 @@ cc_library_with_android_deps(
     name = "ops",
     srcs = ["framework/ops.cc"],
    hdrs = ["framework/ops.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -197,7 +197,7 @@ cc_library_with_android_deps(
         "framework/scope_internal.h",
     ],
     hdrs = ["framework/scope.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     common_deps = [
         ":ops",
     ],
@@ -237,7 +237,7 @@ cc_library_with_android_deps(
     name = "client_session",
     srcs = ["client/client_session.cc"],
     hdrs = ["client/client_session.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     common_deps = [
         ":ops",
         ":scope",
@@ -275,7 +275,7 @@ cc_library_with_android_deps(
     srcs = ["ops/const_op.cc"],
     hdrs = ["ops/const_op.h"],
     android_deps = [
-        "//tensorflow/core:android_tensorflow_lib",
+        "//tensorflow/core:portable_tensorflow_lib",
     ],
     common_deps = [
         ":ops",
@@ -304,7 +304,7 @@ cc_library_with_android_deps(
     srcs = ["ops/while_loop.cc"],
     hdrs = ["ops/while_loop.h"],
     android_deps = [
-        "//tensorflow/core:android_tensorflow_lib",
+        "//tensorflow/core:portable_tensorflow_lib",
     ],
     common_deps = [
         ":cc_ops",
@@ -57,7 +57,22 @@ cc_library(
         "tensor.h",
     ],
     deps = [
+        ":status",
         "//tensorflow/c:tf_datatype",
         "//tensorflow/c:tf_tensor",
     ],
 )

+cc_library(
+    name = "tensorhandle",
+    hdrs = [
+        "tensorhandle.h",
+    ],
+    deps = [
+        ":runtime",
+        ":status",
+        ":tensor",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_experimental",
+    ],
+)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_experimental.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // Runtime represents an opaque instance of a Tensorflow runtime, with its own
@@ -40,6 +41,7 @@ class Runtime {
  private:
   friend class RuntimeBuilder;
   friend class SavedModelAPI;
+  friend class TensorHandle;

   // Wraps a TFE_Context. Takes ownership of ctx.
   explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@@ -63,6 +65,7 @@ class Runtime {
 };

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/cc/experimental/base/public/status.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
 }

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/c/tf_status.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // Status is a wrapper around an error code and an optional error message.
@@ -57,6 +58,7 @@ class Status {
   friend class RuntimeBuilder;
   friend class Runtime;
   friend class SavedModelAPI;
+  friend class TensorHandle;

   // Wraps a TF_Status*, and takes ownership of it.
   explicit Status(TF_Status* status) : status_(status) {}
@@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
 }

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
@@ -19,30 +19,53 @@ limitations under the License.
 #include <stddef.h>
 #include <stdint.h>

+#include <functional>
 #include <memory>
+#include <vector>

 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/cc/experimental/base/public/status.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // Tensor represents an n-dimensional array of values.
 class Tensor {
  public:
-  // TODO(bmzhao): Add a factory function that constructs a Tensor from a char
-  // buffer, with an options struct (to specify the buffer's layout, device?,
-  // whether to create a TFRT or TF tensor, whether we should take ownership of
-  // the memory, etc). This requires extending TF_NewTensor with an options
-  // struct:
-  // https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
+  using DeleterCallback = std::function<void(void*, size_t)>;
+
+  // Constructs a Tensor from user provided buffer.
+  //
+  // Params:
+  //  dtype - The dtype of the tensor's data.
+  //  shape - A shape vector, where each element corresponds to the size of
+  //          the tensor's corresponding dimension.
+  //  data - Pointer to a buffer of memory to construct a Tensor out of.
+  //  len - The length (in bytes) of `data`
+  //  deleter - A std::function to be called when the Tensor no longer needs the
+  //            memory in `data`. This can be used to free `data`, or
+  //            perhaps decrement a refcount associated with `data`, etc.
+  //  status - Set to OK on success and an error on failure.
+  // Returns:
+  //  If an error occurred, status->ok() will be false, and the returned
+  //  Tensor must not be used.
+  // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
+  // a TFRT backed tensor.
+  // TODO(bmzhao): Add benchmarks on overhead for this function; we can
+  // consider using int64_t* + length rather than vector.
+  static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
+                           void* data, size_t len, DeleterCallback deleter,
+                           Status* status);

   // TODO(bmzhao): In the case we construct a tensor from non-owned memory,
   // we should offer a way to deep copy the tensor into a new tensor, which
   // owns the underlying memory. This could be a .deepcopy()/clone() method.

   // TODO(bmzhao): In the future, we want to relax the non-copyability
-  // constraint. To do so, we can add a C API function that acts like CopyFrom:
+  // constraint. To do so, we can add a C API function that acts like
+  // CopyFrom:
   // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311

   // Tensor is movable, but not copyable
@@ -85,6 +108,16 @@ class Tensor {
   // This object retains ownership of the pointer.
   TF_Tensor* GetTFTensor() const { return tensor_.get(); }

+  struct DeleterStruct {
+    std::function<void(void*, size_t)> deleter;
+  };
+
+  static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
+    DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
+    deleter->deleter(memory, len);
+    delete deleter;
+  }
+
   struct TFTensorDeleter {
     void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
   };
@@ -111,7 +144,32 @@ inline size_t Tensor::num_bytes() const {
   return TF_TensorByteSize(tensor_.get());
 }

+inline Tensor Tensor::FromBuffer(TF_DataType dtype,
+                                 const std::vector<int64_t>& shape, void* data,
+                                 size_t len, DeleterCallback deleter,
+                                 Status* status) {
+  // Credit to apassos@ for this technique:
+  // Despite the fact that our API takes a std::function deleter, we are able
+  // to maintain ABI stability because:
+  // 1. Only a function pointer is sent across the C API (&DeleterFunction)
+  // 2. DeleterFunction is defined in the same build artifact that constructed
+  //    the std::function (so there isn't confusion about std::function ABI).
+  // Note that 2. is satisfied by the fact that this is a header-only API, where
+  // the function implementations are inline.
+
+  DeleterStruct* deleter_struct = new DeleterStruct{deleter};
+  TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
+                                   &DeleterFunction, deleter_struct);
+  if (tensor == nullptr) {
+    status->SetStatus(TF_INVALID_ARGUMENT,
+                      "Failed to create tensor for input buffer");
+    return Tensor(nullptr);
+  }
+  return Tensor(tensor);
+}
+
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
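A usage sketch for the FromBuffer factory above. This is hypothetical caller code; the buffer and deleter are illustrative:

// Hypothetical usage: wrap a heap buffer in a Tensor; the DeleterCallback
// frees it once the Tensor is destroyed.
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;

Status status;
float* data = new float[6]{1, 2, 3, 4, 5, 6};
Tensor t = Tensor::FromBuffer(
    TF_FLOAT, /*shape=*/{2, 3}, data, /*len=*/6 * sizeof(float),
    /*deleter=*/[](void* p, size_t) { delete[] static_cast<float*>(p); },
    &status);
if (!status.ok()) { /* handle the error; `t` must not be used. */ }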
98  tensorflow/cc/experimental/base/public/tensorhandle.h  Normal file
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/status.h"
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// An opaque representation of a tensor computed/managed by the Tensorflow
+// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
+// to tensors placed in memory of different devices or remote address spaces.
+// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
+// from it.
+class TensorHandle {
+ public:
+  // Unwraps a Tensor from the given TensorHandle. If an error occurred,
+  // status->ok() will be false, and the returned Tensor must not be used.
+  Tensor Resolve(Status* status);
+
+  // Constructs a TensorHandle from a Tensor. If an error occurred,
+  // status->ok() will be false, and the returned TensorHandle must not be used.
+  static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
+                                 Status* status);
+
+  // TensorHandle is movable, and not copyable
+  TensorHandle(TensorHandle&&) = default;
+  TensorHandle& operator=(TensorHandle&&) = default;
+
+ private:
+  // Wraps a TFE_TensorHandle. Takes ownership of handle.
+  explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
+
+  // TensorHandle is not copyable
+  TensorHandle(const TensorHandle&) = delete;
+  TensorHandle& operator=(const TensorHandle&) = delete;
+
+  // Returns the underlying TFE_TensorHandle that this object wraps.
+  // This object retains ownership of the pointer.
+  TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
+
+  // Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
+  // and takes ownership of handle.
+  void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
+
+  struct TFETensorHandleDeleter {
+    void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
+  };
+  std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
+};
+
+inline Tensor TensorHandle::Resolve(Status* status) {
+  TF_Tensor* tensor =
+      TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
+  if (!status->ok()) {
+    return Tensor(nullptr);
+  }
+  return Tensor(tensor);
+}
+
+inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
+                                             const Runtime& runtime,
+                                             Status* status) {
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
+      runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
+  if (!status->ok()) {
+    return TensorHandle(nullptr);
+  }
+  return TensorHandle(tensor_handle);
+}
+
+}  // namespace cc
+}  // namespace experimental
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
50  tensorflow/cc/experimental/base/tests/BUILD  Normal file
@@ -0,0 +1,50 @@
+# Tests for the C++ header-only base types.
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "tensor_types_test_util",
+    testonly = True,
+    hdrs = ["tensor_types_test_util.h"],
+    deps = [
+        "//tensorflow/c:tf_datatype",
+    ],
+)
+
+tf_cc_test(
+    name = "tensor_test",
+    srcs = [
+        "tensor_test.cc",
+    ],
+    deps = [
+        ":tensor_types_test_util",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/cc/experimental/base/public:status",
+        "//tensorflow/cc/experimental/base/public:tensor",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+tf_cc_test(
+    name = "tensorhandle_test",
+    srcs = [
+        "tensorhandle_test.cc",
+    ],
+    deps = [
+        ":tensor_types_test_util",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/cc/experimental/base/public:runtime",
+        "//tensorflow/cc/experimental/base/public:runtime_builder",
+        "//tensorflow/cc/experimental/base/public:status",
+        "//tensorflow/cc/experimental/base/public:tensor",
+        "//tensorflow/cc/experimental/base/public:tensorhandle",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
163  tensorflow/cc/experimental/base/tests/tensor_test.cc  Normal file
@@ -0,0 +1,163 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace {
+
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
+
+using SimpleTypes = ::testing::Types<
+    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
+
+template <typename T>
+class ConstructScalarTensorTest : public ::testing::Test {};
+TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
+
+// This test constructs a scalar tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
+  Status status;
+  TF_DataType dtype = TypeParam::kDType;
+  typename TypeParam::type value = 42;
+  Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+                                     /*data=*/&value,
+                                     /*len=*/sizeof(value),
+                                     /*deleter=*/[](void*, size_t) {}, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 0);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
+  EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), 1);
+}
+
+template <typename T>
+class Construct1DTensorTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
+
+// This test constructs a 1D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
+  Status status;
+  TF_DataType dtype = TypeParam::kDType;
+  // This is our 1D tensor of varying dtype.
+  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+  // Shape is Rank 1 vector.
+  std::vector<int64_t> shape;
+  shape.push_back(value.size());
+
+  Tensor tensor = Tensor::FromBuffer(
+      /*dtype=*/dtype, /*shape=*/shape,
+      /*data=*/value.data(),
+      /*len=*/value.size() * sizeof(typename TypeParam::type),
+      /*deleter=*/[](void*, size_t) {}, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 1);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+  EXPECT_EQ(tensor_view[0], 42);
+  EXPECT_EQ(tensor_view[1], 100);
+  EXPECT_EQ(tensor_view[2], 0);
+  EXPECT_EQ(tensor_view[3], 1);
+  EXPECT_EQ(tensor_view[4], 4);
+  EXPECT_EQ(tensor_view[5], 29);
+
+  EXPECT_EQ(tensor.num_bytes(),
+            value.size() * sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+template <typename T>
+class Construct2DTensorTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
+
+// This test constructs a 2D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
+  Status status;
+  TF_DataType dtype = TypeParam::kDType;
+  // This is our 2D tensor of varying dtype.
+  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+  // Shape is Rank 2 vector with shape 2 x 3.
+  std::vector<int64_t> shape({2, 3});
+
+  Tensor tensor = Tensor::FromBuffer(
+      /*dtype=*/dtype, /*shape=*/shape,
+      /*data=*/value.data(),
+      /*len=*/value.size() * sizeof(typename TypeParam::type),
+      /*deleter=*/[](void*, size_t) {}, &status);
+
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 2);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+  EXPECT_EQ(tensor_view[0], 42);
+  EXPECT_EQ(tensor_view[1], 100);
+  EXPECT_EQ(tensor_view[2], 0);
+  EXPECT_EQ(tensor_view[3], 1);
+  EXPECT_EQ(tensor_view[4], 4);
+  EXPECT_EQ(tensor_view[5], 29);
+
+  EXPECT_EQ(tensor.num_bytes(),
+            value.size() * sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
+  bool done = false;
+  Status status;
+  std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
+  {
+    // data_vector is a rank 1 tensor.
+    std::vector<int64_t> shape;
+    shape.push_back(data_vector.size());
+
+    Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
+      done = true;
+    };
+
+    Tensor tensor =
+        Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
+                           /*data=*/data_vector.data(),
+                           /*len=*/data_vector.size() * sizeof(int32_t),
+                           /*deleter=*/callback, &status);
+    ASSERT_TRUE(status.ok()) << status.message();
+  }
+  // At this point, tensor has been destroyed, and the deleter callback should
+  // have run.
+  EXPECT_TRUE(done);
+}
+
+}  // namespace
76  tensorflow/cc/experimental/base/tests/tensor_types_test_util.h  Normal file
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
+
+#include <stdint.h>
+
+#include "tensorflow/c/tf_datatype.h"
+
+namespace tensorflow {
+
+// Each of the following struct types have two members: a kDType that
+// corresponds to a TF_Datatype enum value, and a typedef "type"
+// of its corresponding C++ type. These types allow us to write Dtype-agnostic
+// tests via GoogleTest's TypedTests:
+// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
+struct FloatType {
+  using type = float;
+  static constexpr TF_DataType kDType = TF_FLOAT;
+};
+
+struct DoubleType {
+  using type = double;
+  static constexpr TF_DataType kDType = TF_DOUBLE;
+};
+
+struct Int32Type {
+  using type = int32_t;
+  static constexpr TF_DataType kDType = TF_INT32;
+};
+
+struct UINT8Type {
+  using type = uint8_t;
+  static constexpr TF_DataType kDType = TF_UINT8;
+};
+
+struct INT8Type {
+  using type = int8_t;
+  static constexpr TF_DataType kDType = TF_INT8;
+};
+
+struct INT64Type {
+  using type = int64_t;
+  static constexpr TF_DataType kDType = TF_INT64;
+};
+
+struct UINT16Type {
+  using type = uint16_t;
+  static constexpr TF_DataType kDType = TF_UINT16;
+};
+
+struct UINT32Type {
+  using type = uint32_t;
+  static constexpr TF_DataType kDType = TF_UINT32;
+};
+
+struct UINT64Type {
+  using type = uint64_t;
+  static constexpr TF_DataType kDType = TF_UINT64;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/experimental/base/public/tensorhandle.h"

#include <stddef.h>
#include <stdint.h>

#include <memory>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;

using SimpleTypes = ::testing::Types<
    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;

template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);

// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
// verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  Status status;
  RuntimeBuilder runtime_builder;
  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  TF_DataType dtype = TypeParam::kDType;
  typename TypeParam::type value = 42;
  Tensor original_tensor =
      Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
                         /*data=*/&value,
                         /*len=*/sizeof(value),
                         /*deleter=*/[](void*, size_t) {}, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  TensorHandle handle =
      TensorHandle::FromTensor(original_tensor, *runtime, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  Tensor tensor = handle.Resolve(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  EXPECT_EQ(tensor.dims(), 0);
  EXPECT_EQ(tensor.dtype(), dtype);
  EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
  EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
  EXPECT_EQ(tensor.num_elements(), 1);
}

template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);

// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  Status status;
  RuntimeBuilder runtime_builder;
  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  TF_DataType dtype = TypeParam::kDType;
  // This is our 1D tensor of varying dtype.
  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
  // Shape is a rank-1 vector.
  std::vector<int64_t> shape;
  shape.push_back(value.size());

  Tensor original_tensor = Tensor::FromBuffer(
      /*dtype=*/dtype, /*shape=*/shape,
      /*data=*/value.data(),
      /*len=*/value.size() * sizeof(typename TypeParam::type),
      /*deleter=*/[](void*, size_t) {}, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  TensorHandle handle =
      TensorHandle::FromTensor(original_tensor, *runtime, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  Tensor tensor = handle.Resolve(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  EXPECT_EQ(tensor.dims(), 1);
  EXPECT_EQ(tensor.dtype(), dtype);
  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
  EXPECT_EQ(tensor_view[0], 42);
  EXPECT_EQ(tensor_view[1], 100);
  EXPECT_EQ(tensor_view[2], 0);
  EXPECT_EQ(tensor_view[3], 1);
  EXPECT_EQ(tensor_view[4], 4);
  EXPECT_EQ(tensor_view[5], 29);

  EXPECT_EQ(tensor.num_bytes(),
            value.size() * sizeof(typename TypeParam::type));
  EXPECT_EQ(tensor.num_elements(), value.size());
}

template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);

// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorHandleTest,
           ValidTensorAttributesAfterConstruction) {
  Status status;
  RuntimeBuilder runtime_builder;
  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  TF_DataType dtype = TypeParam::kDType;
  // This is the flat data for our 2D tensor of varying dtype.
  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
  // Shape is a rank-2 tensor with shape 2 x 3.
  std::vector<int64_t> shape({2, 3});

  Tensor original_tensor = Tensor::FromBuffer(
      /*dtype=*/dtype, /*shape=*/shape,
      /*data=*/value.data(),
      /*len=*/value.size() * sizeof(typename TypeParam::type),
      /*deleter=*/[](void*, size_t) {}, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  TensorHandle handle =
      TensorHandle::FromTensor(original_tensor, *runtime, &status);
  ASSERT_TRUE(status.ok()) << status.message();

  Tensor tensor = handle.Resolve(&status);
  ASSERT_TRUE(status.ok()) << status.message();

  EXPECT_EQ(tensor.dims(), 2);
  EXPECT_EQ(tensor.dtype(), dtype);
  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
  EXPECT_EQ(tensor_view[0], 42);
  EXPECT_EQ(tensor_view[1], 100);
  EXPECT_EQ(tensor_view[2], 0);
  EXPECT_EQ(tensor_view[3], 1);
  EXPECT_EQ(tensor_view[4], 4);
  EXPECT_EQ(tensor_view[5], 29);

  EXPECT_EQ(tensor.num_bytes(),
            value.size() * sizeof(typename TypeParam::type));
  EXPECT_EQ(tensor.num_elements(), value.size());
}

}  // namespace
}  // namespace tensorflow
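The tests above pass no-op deleters because their buffers are owned by locals. A hedged sketch of handing Tensor::FromBuffer a heap buffer that it should reclaim, assuming the same signature the tests use; the helper name is ours:

#include <cstddef>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"

using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;

Tensor MakeHeapBackedTensor(Status* status) {
  // The tensor takes over this allocation; the deleter frees it when the
  // buffer is released, unlike the no-op deleters used for stack-owned data
  // in the tests above.
  float* data = new float[6]{42, 100, 0, 1, 4, 29};
  return Tensor::FromBuffer(
      /*dtype=*/TF_FLOAT, /*shape=*/{2, 3},
      /*data=*/data,
      /*len=*/6 * sizeof(float),
      /*deleter=*/[](void* d, size_t) { delete[] static_cast<float*>(d); },
      status);
}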
@@ -4,7 +4,6 @@
 load(
     "//tensorflow:tensorflow.bzl",
     "if_android",
-    "if_ios",
     "if_mobile",
     "if_not_mobile",
     "tf_cc_test",
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
 }

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // ConcreteFunctionList helps convert an opaque pointer to an array of
@@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
 }

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // FunctionMetadata stores additional function information, including
@@ -40,6 +41,7 @@ class FunctionMetadata final {
 };

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"

 namespace tensorflow {
+namespace experimental {
 namespace cc {

 // SavedModelAPI offers a way to load Tensorflow Saved Models
@@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
 }

 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
@@ -26,10 +26,14 @@ limitations under the License.
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"

 namespace tensorflow {
-
 namespace {

+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::SavedModelAPI;
+using tensorflow::experimental::cc::Status;
+
 constexpr char kTestData[] = "cc/saved_model/testdata";

 std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
 class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};

 TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
    GTEST_SKIP();  // TODO(chky): Enable this once TFRT is open sourced.
   }

   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();

   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
   std::unordered_set<std::string> tags = {"serve"};
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status, &tags);

   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
 }

 TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
    GTEST_SKIP();  // TODO(chky): Enable this once TFRT is open sourced.
   }

   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();

   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status);

   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

 }  // namespace
-
 }  // namespace tensorflow
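The point of the hunks above is moving the public C++ SavedModel classes under tensorflow::experimental::cc; a namespace alias keeps call sites short. A minimal hedged sketch, where the alias name and the model path are ours, not from the commit:

#include <memory>

#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

namespace tf_cc = tensorflow::experimental::cc;  // hypothetical alias

void LoadExample() {
  tf_cc::Status status;
  tf_cc::RuntimeBuilder builder;
  std::unique_ptr<tf_cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return;
  std::unique_ptr<tf_cc::SavedModelAPI> model =
      tf_cc::SavedModelAPI::Load("/path/to/saved_model",  // illustrative path
                                 *runtime, &status);
}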
@@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
   TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
   std::vector<string> dim_vars;
   string dim_sizes, indices;
+  int count = 1;
   if (shape.rank() == 0 ||
       (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
     dim_sizes = "[1]";
@@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
       dim_vars.push_back(absl::StrCat("size_t dim", dim));
       dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
       indices += absl::StrCat("[dim", dim, "]");
+      count *= shape.dimensions(dim);
     }
   }
   rewrites->push_back({"{{I}}", absl::StrCat(i)});
@@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
   rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
   rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
   rewrites->push_back({"{{INDICES}}", indices});
+  rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
   return Status::OK();
 }

@@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
     return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
         arg_data({{I}}))){{INDICES}};
   }
+  int arg{{NAME}}_size() const {
+    return {{COUNT}} * sizeof({{TYPE}});
+  }
+  int arg{{NAME}}_count() const {
+    return {{COUNT}};
+  }
 )";
     *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     if (!config.feed(i).name().empty()) {
@@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
     return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
         result_data({{I}}))){{INDICES}};
   }
+  int result{{NAME}}_size() const {
+    return {{COUNT}} * sizeof({{TYPE}});
+  }
+  int result{{NAME}}_count() const {
+    return {{COUNT}};
+  }
 )";
     *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     if (!config.fetch(i).name().empty()) {
@@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
     return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
         arg_data({{I}}))){{INDICES}};
   }
+  int var_{{NAME}}_size() const {
+    return {{COUNT}} * sizeof({{TYPE}});
+  }
+  int var_{{NAME}}_count() const {
+    return {{COUNT}};
+  }
 )";
     *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     const tf2xla::Variable& var = config.variable(i - config.feed_size());
     rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
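To make the new {{COUNT}} substitution concrete: for a rank-2 shape {2, 3} the loop above produces dim_sizes "[2][3]", indices "[dim0][dim1]", and count 6. A standalone hedged sketch of that string-building, using a plain int vector in place of xla::Shape:

#include <cstddef>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"

int main() {
  const std::vector<int> dimensions = {2, 3};  // stand-in for xla::Shape dims
  std::string dim_sizes, indices;
  int count = 1;
  for (size_t dim = 0; dim < dimensions.size(); ++dim) {
    dim_sizes += absl::StrCat("[", dimensions[dim], "]");  // -> "[2][3]"
    indices += absl::StrCat("[dim", dim, "]");             // -> "[dim0][dim1]"
    count *= dimensions[dim];                              // -> 6
  }
  // {{DIM_SIZES}}, {{INDICES}}, and {{COUNT}} are rewritten with these
  // values, so a generated argN_size() returns 6 * sizeof(element type).
  return count == 6 ? 0 : 1;
}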
@@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const float(*)[1][2]>(
         arg_data(0)))[dim0][dim1];
   }
+  int arg0_size() const {
+    return 2 * sizeof(float);
+  }
+  int arg0_count() const {
+    return 2;
+  }

   void set_arg_myfeed_data(const void* data) {
     set_arg_data(0, data);
@@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const float(*)[1][2]>(
         arg_data(0)))[dim0][dim1];
   }
+  int arg_myfeed_size() const {
+    return 2 * sizeof(float);
+  }
+  int arg_myfeed_count() const {
+    return 2;
+  }

   void set_arg1_data(const void* data) {
     set_arg_data(1, data);
@@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const tensorflow::int64(*)[3][4]>(
         arg_data(1)))[dim0][dim1];
   }
+  int arg1_size() const {
+    return 12 * sizeof(tensorflow::int64);
+  }
+  int arg1_count() const {
+    return 12;
+  }

   // Result methods for managing output buffers. Buffers are in row-major order.
   // Must only be called after a successful Run call. There is a set of methods
@@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const tensorflow::uint32(*)[5][6]>(
         result_data(0)))[dim0][dim1];
   }
+  int result0_size() const {
+    return 30 * sizeof(tensorflow::uint32);
+  }
+  int result0_count() const {
+    return 30;
+  }

   tensorflow::uint32* result_myfetch_data() {
     return static_cast<tensorflow::uint32*>(result_data(0));
@@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const tensorflow::uint32(*)[5][6]>(
         result_data(0)))[dim0][dim1];
   }
+  int result_myfetch_size() const {
+    return 30 * sizeof(tensorflow::uint32);
+  }
+  int result_myfetch_count() const {
+    return 30;
+  }

   // Methods for managing variable buffers. Buffers are in row-major order.
   //
@@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const float(*)[1]>(
         arg_data(2)))[0];
   }
+  int var_myvar_readonly_size() const {
+    return 1 * sizeof(float);
+  }
+  int var_myvar_readonly_count() const {
+    return 1;
+  }

   void set_var_myvar_data(float* data) {
     set_arg_data(3, data);
@@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const float(*)[1]>(
         arg_data(3)))[0];
   }
+  int var_myvar_size() const {
+    return 1 * sizeof(float);
+  }
+  int var_myvar_count() const {
+    return 1;
+  }

   void set_var_myvar2_data(tensorflow::int32* data) {
     set_arg_data(4, data);
@@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
     return (*static_cast<const tensorflow::int32(*)[5]>(
         arg_data(4)))[dim0];
   }
+  int var_myvar2_size() const {
+    return 5 * sizeof(tensorflow::int32);
+  }
+  int var_myvar2_count() const {
+    return 5;
+  }

  private:
   // Number of buffers for the compiled computation.
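A hedged usage sketch of the new accessors on the generated class shown above; MyClass and its buffer layout come from that listing, while the helper and its setup are ours:

#include <cassert>
#include <vector>

void UseGeneratedSizes(MyClass& computation) {
  // arg_myfeed has shape (1, 2) of float, so count is 2 and size is 8 bytes.
  std::vector<float> feed(computation.arg_myfeed_count());
  assert(computation.arg_myfeed_size() ==
         static_cast<int>(feed.size() * sizeof(float)));
  computation.set_arg_myfeed_data(feed.data());
}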
@@ -180,11 +180,9 @@ class XlaAssignVariableOp : public OpKernel {
                           data::MakeIteratorOp);                             \
   REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),          \
                           data::AnonymousIteratorHandleOp);                  \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"),      \
+  REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE),        \
                           data::AnonymousIteratorHandleOp);                  \
-  REGISTER_KERNEL_BUILDER(                                                   \
-      Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"),           \
+  REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE),             \
                           data::DeleteIteratorOp);                           \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),            \
                           data::IteratorGetNextOp);                          \
@@ -31,7 +31,7 @@ filegroup(
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
     ],
 )
@@ -695,9 +695,9 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LoopOpsTransforms",
         "@llvm-project//mlir:MlirTranslateMain",
         "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:SCFTransforms",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Translation",
@@ -1020,7 +1020,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
   if (!inst->getMutableAttrDict().getAttrs().empty()) {
     os << " {";
     bool first = true;
-    for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) {
+    for (auto& named_attr : inst->getAttrDictionary()) {
       os << (!first ? ", " : "");
       first = false;
       named_attr.first.print(os);
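The hunk above swaps a mutable attribute dictionary for MLIR's read-only accessor. A hedged sketch of the resulting iteration pattern against the MLIR API of this era; the helper name is ours:

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h"

// Mirrors the loop in the hunk above: getAttrDictionary() yields the
// operation's attributes without materializing a mutable dictionary first.
void PrintAttrNames(mlir::Operation* inst, llvm::raw_ostream& os) {
  bool first = true;
  for (auto& named_attr : inst->getAttrDictionary()) {
    os << (!first ? ", " : "");
    first = false;
    named_attr.first.print(os);  // the attribute's name identifier
  }
}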
@@ -20,7 +20,7 @@ limitations under the License.

 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
-include "mlir/Interfaces/SideEffects.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
 include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
@@ -247,7 +247,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
     Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
         CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;

-class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
+class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
+  And<[
+    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
+        CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>,
+    Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
+        CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
+
+class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
   PredOpTrait<"operand " # n # " is at most " # m # "-D",
     Or<[
       CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
@@ -255,13 +262,13 @@ class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
       CPred<"$_op.getOperand(" # n #
         ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

-class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
+class TFL_OperandHasRankAtMost<int n, int m> :
   PredOpTrait<"operand " # n # " is at most " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
       CPred<"$_op.getOperand(" # n #
         ").getType().cast<ShapedType>().getRank() <= " # m>]>>;

-class TFL_OperandHasRankGreaterThanOrEqualTo<int n, int m> :
+class TFL_OperandHasRankAtLeast<int n, int m> :
   PredOpTrait<"operand " # n # " is at least " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
       CPred<"$_op.getOperand(" # n #
@@ -300,6 +307,18 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
       "quant::QuantizedType::castToStorageType("
         "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;

+// This is a quantization-aware version of TCopVTEtAreSameAt
+class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
+  TCopVTEtAreSameAt<[i, j]>,
+  TFL_TFOperandTypesWithSameBits<i, j, 8>,
+  And<[
+    SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
+      quant_QuantizedType.predicate>,
+    CPred<"quant::QuantizedType::castToStorageType("
+        "getElementTypeOrSelf($_op.getOperand(" # i # "))) == "
+        "quant::QuantizedType::castToStorageType("
+        "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>;
+
 //===----------------------------------------------------------------------===//
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
@@ -963,7 +982,11 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [

 // Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
 def TFL_LessEqualOp : TFL_Op<"less_equal", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Less_equal operator";

   let description = [{
@@ -971,8 +994,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs,
-    TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs);
+    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs);

   let results = (outs TFL_BoolTensor:$output);

@@ -985,8 +1008,11 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
   let hasOptions = 0;
 }

-def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization",
-    [NoSideEffect]> {
+def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [
+    TFL_OperandHasRank<0, 4>,
+    SameOperandsAndResultShape,
+    SameOperandsAndResultType,
+    NoSideEffect]> {
   let summary = "Local Response Normalization.";

   let description = [{
@@ -1004,7 +1030,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    TFL_FpTensor:$input,
     I32Attr:$radius,
     F32Attr:$bias,
     F32Attr:$alpha,
@@ -1012,7 +1038,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
   );

   let results = (outs
-    TFL_TensorOf<[F32, QI8, QUI8]>:$output
+    TFL_FpTensor:$output
   );

   let hasOptions = 1;
@@ -1048,7 +1074,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
     NoSideEffect,
     TFL_OperandHasAtleastRank<0, 1>,
     PredOpTrait<"operand and result must have the same element type",
-      TCresVTEtIsSameAsOp<0, 0>>]> {
+      TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
   let summary = [{
 Returns a tensor with the provided diagonal and everything else padded with zeros.
   }];
@@ -1061,17 +1087,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal
+    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal
   );

   let results = (outs
-    TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
+    TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output
   );

   let hasOptions = 0;
 }

-def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> {
+def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [
+    TFL_OperandHasAtleastRank<0, 2>,
+    PredOpTrait<"input and result must have the same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect]> {
   let summary = [{
 Returns a batched matrix tensor with new batched diagonal values.
   }];
@@ -1083,12 +1113,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
   }];

   let arguments = (ins
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input,
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
   );

   let results = (outs
-    TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output
+    TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
   );

   let hasOptions = 0;
@@ -1206,7 +1236,12 @@ larger than 0.
 }

 def TFL_NotEqualOp : TFL_Op<"not_equal", [
-    ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    BinaryOpSameElementTypeConstraint,
+    ResultsBroadcastableShape,
+    Commutative,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Not_equal operator";

   let description = [{
@@ -1214,8 +1249,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
+    TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);

   let results = (outs TFL_BoolTensor:$output);

@@ -1284,7 +1319,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
     PredOpTrait<"value and output must have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 1>>,
     TFL_OperandHasRank<0, 1>,
-    TFL_OperandHasRankGreaterThanOrEqualTo<1, 2>
+    TFL_OperandHasRankAtLeast<1, 2>
   ]> {
   let summary = "Embedding lookup operator";
@@ -1502,7 +1537,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
 }

 def TFL_GreaterOp : TFL_Op<"greater", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Greater operator";

   let description = [{
@@ -1510,10 +1549,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_BoolTensor:$output);

   let builders = [TFL_ComparisonBinaryBuilder];

@@ -1524,6 +1563,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [

 def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
                                            SameOperandsAndResultShape,
+                                           SameOperandsAndResultType,
                                            TFL_GpuTargetOp]> {
   let summary = "Hardswish activation function.";
   let description = [{
@@ -1563,29 +1603,34 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
   let customOption = "L2NormOptions";
 }

-def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> {
+def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
+    SameOperandsAndResultShape,
+    NoSideEffect,
+    SameOperandsAndResultType]> {
   let summary = "Leaky Relu operator";

-  // TODO(jpienaar): Add type restriction. This op is only defined for
-  // restricted (floating point) types.
   let description = [{
 Element-wise Leaky ReLU operator
   x -> x >= 0 ? x : (alpha * x)
   }];

   let arguments = (
-    ins AnyTensor:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
     // Slope of the activation function at x < 0.
     F32Attr:$alpha
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);

   let hasOptions = 0b1;
 }

 def TFL_LessOp : TFL_Op<"less", [
-    ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
+    ResultsBroadcastableShape,
+    BinaryOpSameElementTypeConstraint,
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    NoSideEffect,
+    NoQuantizableResult]> {
   let summary = "Less operator";

   let description = [{
@@ -1593,8 +1638,8 @@ def TFL_LessOp : TFL_Op<"less", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs);
+    ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);

   let results = (outs TFL_BoolTensor:$output);
@@ -1655,6 +1700,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> {

 def TFL_LogisticOp: TFL_Op<"logistic", [
     NoSideEffect,
+    PredOpTrait<"x and y must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     SameOperandsAndResultShape,
     // zero_point = 0
     // scale = 1. / (max_value + 1)
@@ -1667,9 +1714,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
     Computes element-wise Sigmoid of input
   }];

-  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x);
+  let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);

-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
 }

 def TFL_LogOp: TFL_Op<"log", [
@@ -1690,10 +1737,10 @@ def TFL_LogOp: TFL_Op<"log", [
   let hasFolder = 1;
 }

-// TODO(b/130643170): Adds some constraint for the input/output element types.
 def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
     NoSideEffect,
     SameOperandsAndResultShape,
+    SameOperandsAndResultType,
     // zero_point = max_value
     // scale = -log_softmax_output_min / (max_value + 1)
     FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
@@ -1706,9 +1753,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
   input - log(reduce_sum(exp(input), dim))
   }];

-  let arguments = (ins AnyTensor:$input);
+  let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input);

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }
@@ -1727,6 +1774,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
          TFL_TCresVTEtIsSameAsOp<0, 0>]>>;

 def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
+    TFL_OperandHasRank<0, 4>,
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     MaxPoolOperandAndResultConstraints,
     SameOperandsAndResultsScale,
@@ -1741,7 +1791,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
   }];

   let arguments = (
-    ins AnyTensor:$input,
+    ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input,
     TFL_PaddingAttr:$padding,
     I32Attr:$stride_w,
     I32Attr:$stride_h,
@@ -1750,7 +1800,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
     TFL_AFAttr:$fused_activation_function
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output);

   let hasOptions = 1;
@@ -1782,7 +1832,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
   let hasOptions = 0;
 }

-def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
+def TFL_MeanOp : TFL_Op<"mean", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect,
+    TFL_GpuTargetOp]> {
   let summary = "Mean operator";

   let description = [{
@@ -1794,13 +1848,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
     TFL_TensorOf<[I32, I64]>:$axis,
     BoolAttr:$keep_dims
   );

   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output);
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output);

   let hasOptions = 1;
   let customOption = "ReducerOptions";
@@ -1821,14 +1875,14 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
   let arguments = (ins
     TFL_TensorOf<[I32, I64]>:$indices,
     TFL_I32Tensor:$depth,
-    TFL_TensorOf<[F32, I32, I64, I1]>:$on_value,
-    TFL_TensorOf<[F32, I32, I64, I1]>:$off_value,
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value,
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value,

     I32Attr:$axis
   );

   let results = (outs
-    TFL_TensorOf<[F32, I32, I64, I1]>:$output
+    TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output
   );

   let hasOptions = 1;
@@ -2032,7 +2086,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
   let hasFolder = 1;
 }

-def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
+def TFL_PackOp : TFL_Op<"pack", [
+    PredOpTrait<"values and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoSideEffect,
+    SameOperandsAndResultsScale]> {
   let summary = "Packs a list of tensors along a dimension into one tensor";

   let description = [{
@@ -2063,14 +2121,14 @@ def TFL_PackOp : TFL_Op<"pack", [
   }];

   let arguments = (ins
-    TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values,
+    TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values,

-    I32Attr:$values_count,
+    Confined<I32Attr, [IntPositive]>:$values_count,
     I32Attr:$axis
   );

   let results = (outs
-    TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
+    TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output
   );

   let verifier = [{ return Verify(*this); }];
@@ -2081,8 +2139,11 @@ def TFL_PackOp : TFL_Op<"pack", [
 }

 def TFL_PadOp : TFL_Op<"pad", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     SameOperandsAndResultsScale,
+    TFL_OperandHasRankAtMost<0, 4>,
     TFL_OperandHasRank<1, 2>,
     TFL_OperandRankEquals1DimOfOperand<0, 1>,
     TFL_GpuTargetOp]> {
@@ -2113,22 +2174,25 @@ def TFL_PadOp : TFL_Op<"pad", [
   ```
   }];

-  let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
+  let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
     TFL_I32OrI64Tensor:$padding);

-  let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }

 def TFL_PadV2Op : TFL_Op<"padv2", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
     NoSideEffect,
     SameOperandsAndResultsScale,
+    TFL_OperandHasRankAtMost<0, 4>,
     TFL_OperandHasRank<1, 2>,
     TFL_OperandHasRank<2, 0>,
     TFL_OperandRankEquals1DimOfOperand<0, 1>,
     PredOpTrait<"input and constant value operands must have same element type",
-      TCopVTEtAreSameAt<[0, 2]>>]> {
+      TFL_TCopVTEtAreSameAt<0, 2>>]> {
   let summary = "Padding operator v2";

   let description = [{
@@ -2159,11 +2223,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input,
     TFL_I32OrI64Tensor:$padding,
-    TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values);
+    TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values);

-  let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output);

   let hasOptions = 1;
 }
@@ -2191,9 +2255,21 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
   let builders = [TFL_BroadcastableBinaryBuilder];
 }

-def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
+def TFL_PReluOp : TFL_Op<"prelu", [
+    NoSideEffect,
+    ResultsBroadcastableShape,
     TFL_GpuTargetOp,
-    SameOperandsAndResultsScale]> {
+    TFL_OperandHasRankAtMost<0, 4>,
+    TFL_OperandHasRankAtMost<1, 4>,
+    BinaryOpSameElementTypeConstraint,
+    PredOpTrait<"input and output must have the same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    PredOpTrait<"'alpha' should have one less rank than 'input'.",
+      Or<[TFL_OperandIsUnrankedPred<0>,
+        TFL_OperandIsUnrankedPred<1>,
+        CPred<"$_op.getOperand(0).getType().cast<ShapedType>().getRank() == "
+              "$_op.getOperand(1).getType().cast<ShapedType>().getRank() "
+              "+ 1">]>>]> {
   let summary = "Parameterized Relu operator";

   let description = [{
@@ -2206,11 +2282,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, QUI8]>:$input,
-    TFL_TensorOf<[F32, QUI8]>:$alpha
+    ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
+    TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha
   );

-  let results = (outs TFL_TensorOf<[F32, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);

   let verifier = [{ return Verify(*this); }];
 }
@@ -2887,7 +2963,7 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
     SameOperandsAndResultsScale,
     PredOpTrait<"input and output must have same element type",
       TFL_TCresVTEtIsSameAsOp<0, 0>>,
-    TFL_OperandHasRankLessThanOrEqualTo<0, 4>
+    TFL_OperandHasRankAtMost<0, 4>
   ]> {
   let summary = "DepthToSpace operator";

@@ -3224,7 +3300,7 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
     ElementsAttr:$value
   );

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);

   let builders = [OpBuilder<
     "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value",
@@ -3849,7 +3925,7 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[QI8, QUI8, QI16, QUI16]>:$input,
+    TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input,
     TFL_TensorOf<[F32]>:$ref,

     // Attributes
@@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
       saved_model_exported_names.begin(), saved_model_exported_names.end());
   absl::Span<std::string> exported_names(exported_names_in_vector);

+  if (exported_names.size() != 1) {
+    return errors::Unimplemented("Only support a single exported name.");
+  }
+
   TF_ASSIGN_OR_RETURN(auto module,
                       ImportSavedModel(model_flags.saved_model_dir(),
                                        model_flags.saved_model_version(), tags,
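From the Python side, this new check means a SavedModel exporting more than one signature is rejected during TFLite conversion. A hedged sketch of the user-visible behavior; the converter entry point is the standard TF 2.x API, while the model path and the two-signature setup are hypothetical:

```python
import tensorflow as tf

# Assumes /tmp/two_sig_model was saved with two exported signatures.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/two_sig_model")

# With the check above, conversion surfaces
# "Only support a single exported name." until one signature is selected.
tflite_model = converter.convert()
```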
@@ -573,7 +573,7 @@ func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> {
 // test invalid Logistic input
 func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
 ^bb0(%arg0: tensor<?xi32>):
-  // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
+  // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor<?xi32>'}}
   %0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
   return %0#0 : tensor<?xi32>
 }
@@ -1252,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %

 // -----

-func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi8> {
-  // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}}
-  %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi8>
-  return %0 : tensor<*xi8>
+func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi16> {
+  // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}}
+  %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi16>
+  return %0 : tensor<*xi16>
 }

 // -----
@@ -1489,7 +1489,8 @@ func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor<?xi

 // -----

-func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
+func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
+  // expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform<u8:f32, 2.000000e-02>>'}}
   %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
   return %0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
 }
@@ -1523,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x

 // -----

-func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> {
-  // expected-error @+1 {{'input' and 'output' should have the same rank}}
-  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32>
-  return %0 : tensor<10x10x10xf32>
+func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> {
+  // expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}}
+  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
+  return %0 : tensor<10x10xf32>
 }

 // -----

 func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
-  // expected-error @+1 {{'input' and 'output' should have the same shape}}
+  // expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}}
   %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
   return %0 : tensor<1x2x3x5xf32>
 }

 // -----

-func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
+func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
   // expected-error @+1 {{'alpha' should have one less rank than 'input'.}}
-  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
+  %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
   return %0 : tensor<7x3x2x14xf32>
 }

 // -----

 func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
-  // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}}
+  // expected-error @+1 {{'tfl.prelu' op operands don't have broadcast-compatible shapes}}
   %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
   return %0 : tensor<15x14x2x14xf32>
 }
@@ -160,6 +160,11 @@ int main(int argc, char **argv) {
       absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
   absl::Span<std::string> exported_names(exported_names_vector);

+  if (exported_names.size() != 1) {
+    llvm::errs() << "There should be only one exported name";
+    return kTrFailure;
+  }
+
   module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
                                         tags, exported_names, &context);
 } else {
@@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
     return module;
   } else if (saved_model_version == 1) {
     auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
-        input_filename, tags, context);
+        input_filename, tags, exported_names, context);

     if (!module)
       return tensorflow::errors::InvalidArgument("fail to open input file");
@@ -12,6 +12,22 @@ cc_library(
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
+        # (yongtang) The graph_optimization_pass_registration needs to be part
+        # of a shared object that will be loaded whenever `import tensorflow`
+        # is run. The natural place is libtensorflow_framework.so.
+        # While adding graph_optimization_pass_registration to
+        # libtensorflow_framework.so is possible with some modification in
+        # dependency, many tests will fail due to multiple copies of LLVM.
+        # See https://github.com/tensorflow/tensorflow/pull/39231 for details.
+        # Alternatively, we place graph_optimization_pass_registration here
+        # because:
+        # - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway
+        # - tensorflow/python/_pywrap_mlir.so always loaded as part of python
+        #   binding
+        # TODO: It might be still preferable to place graph_optimization_pass
+        # as part of the libtensorflow_framework.so, as it is the central
+        # place for core related components.
+        "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
         "//tensorflow/compiler/mlir/tensorflow:import_utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
@@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
   // Convert the SavedModelBundle to an MLIR module.

   mlir::MLIRContext context;
-  auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
+  auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
   if (!module_or.status().ok()) {
     Set_TF_Status_from_Status(status, module_or.status());
     return "// error";
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")

package(licenses = ["notice"])

tf_python_pybind_extension(
    name = "mlir_wrapper",
    srcs = [
        "attrs.cc",
        "basic_classes.cc",
        "builders.cc",
        "mlir_wrapper.cc",
        "mlir_wrapper.h",
        "ops.cc",
        "types.cc",
    ],
    module_name = "mlir_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@pybind11",
    ],
)

tf_python_pybind_extension(
    name = "filecheck_wrapper",
    srcs = ["filecheck_wrapper.cc"],
    module_name = "filecheck_wrapper",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:support",
        "@pybind11",
    ],
)
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc (new file, 25 lines)
@@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"

void init_attrs(py::module& m) {
  py::class_<mlir::Attribute>(m, "Attribute");
  py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr")
      .def("get",
           py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
}
tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc (new file, 49 lines)
@@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "llvm/Support/FileCheck.h"
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"

void init_basic_classes(py::module& m) {
  py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());

  py::class_<mlir::Location>(m, "Location");

  py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
      .def("get", &mlir::UnknownLoc::get);

  py::class_<mlir::Region>(m, "Region")
      .def("back", &mlir::Region::back, py::return_value_policy::reference)
      .def("front", &mlir::Region::front, py::return_value_policy::reference)
      .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
      .def("push_back", &mlir::Region::push_back)
      .def("size", [](mlir::Region& r) { return r.getBlocks().size(); })
      .def("front", &mlir::Region::front, py::return_value_policy::reference);
  py::class_<mlir::Block::iterator>(m, "Block_Iterator");
  py::class_<mlir::Block>(m, "Block")
      .def("new", ([]() { return new mlir::Block; }),
           py::return_value_policy::reference)
      .def("end", &mlir::Block::end)
      .def("addArgument", &mlir::Block::addArgument);

  py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
  py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
  py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
}
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc (new file, 51 lines)
@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "mlir/IR/Builders.h"  // from @llvm-project

#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"

void init_builders(py::module& m) {
  py::class_<mlir::Builder>(m, "Builder")
      .def(py::init<mlir::MLIRContext*>())
      .def("getFunctionType",
           [](mlir::Builder& b, std::vector<mlir::Type> inputs,
              std::vector<mlir::Type> outputs) {
             return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
                                      llvm::ArrayRef<mlir::Type>(outputs));
           });
  py::class_<mlir::OpBuilder>(m, "OpBuilder")
      .def(py::init<mlir::MLIRContext*>())
      .def(py::init<mlir::Region&>())
      .def(py::init<mlir::Operation*>())
      .def(py::init<mlir::Block*, mlir::Block::iterator>())
      .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc)
      .def("setInsertionPoint",
           py::overload_cast<mlir::Block*, mlir::Block::iterator>(
               &mlir::OpBuilder::setInsertionPoint))
      .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint)
      .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint)
      .def(
          "createOperation",
          [](mlir::OpBuilder& opb, mlir::OperationState& state) {
            return opb.createOperation(state);
          },
          py::return_value_policy::reference)
      .def("getContext", &mlir::OpBuilder::getContext,
           py::return_value_policy::reference);

  py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
      .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
}
tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc (new file, 36 lines)
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "llvm/Support/FileCheck.h"
#include "llvm/Support/SourceMgr.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"

PYBIND11_MODULE(filecheck_wrapper, m) {
  m.def("check", [](std::string input, std::string check) {
    llvm::FileCheckRequest fcr;
    llvm::FileCheck fc(fcr);
    llvm::SourceMgr SM = llvm::SourceMgr();
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                          llvm::SMLoc());
    SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
                          llvm::SMLoc());
    llvm::Regex regex = fc.buildCheckPrefixRegex();
    fc.readCheckFile(SM, llvm::StringRef(check), regex);
    return fc.checkInput(SM, llvm::StringRef(input));
  });
}
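This binding makes LLVM's FileCheck callable directly from Python tests. A minimal usage sketch, assuming the extension has been built and is importable under the package path below (the import path is an assumption; `check(input, check)` is the binding defined above and returns True when every CHECK directive matches):

```python
# Hypothetical import path; depends on how the pybind extension is packaged.
from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw

ir_text = "func @main() { return }"
# check() runs FileCheck with ir_text as input and the second
# argument as the check file contents.
assert fw.check(ir_text, "CHECK: func @main")
```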
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc (new file, 38 lines)
@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"

PYBIND11_MODULE(mlir_wrapper, m) {
  m.def("registerDialects", []() {
    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::StandardOpsDialect>();
  });

  init_basic_classes(m);
  init_types(m);
  init_builders(m);
  init_ops(m);
  init_attrs(m);
}
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h (new file, 30 lines)
@@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;

void init_basic_classes(py::module& m);
void init_types(py::module& m);
void init_builders(py::module& m);
void init_ops(py::module& m);
void init_attrs(py::module& m);

#endif  // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc (new file, 194 lines)
@@ -0,0 +1,194 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project

#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

void init_ops(py::module& m) {
  py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
      m, "Operation")
      .def("getRegion", &mlir::Operation::getRegion,
           py::return_value_policy::reference)
      .def("getResult", &mlir::Operation::getResult)
      .def("dump", &mlir::Operation::dump)
      .def("getNumResults", &mlir::Operation::getNumResults);

  py::class_<mlir::OperationState>(m, "OperationState")
      .def(py::init([](mlir::Location loc, std::string name) {
        return mlir::OperationState(loc, llvm::StringRef(name));
      }))
      .def("addTypes",
           [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
             state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
           })
      .def("addOperands",
           [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
             os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
           })
      .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
           py::return_value_policy::reference);

  py::class_<mlir::ModuleOp>(m, "ModuleOp")
      .def("create",
           [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
      .def("push_back",
           [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
      .def("dump", &mlir::ModuleOp::dump)
      .def("getAsStr", [](mlir::ModuleOp& m) {
        std::string str;
        llvm::raw_string_ostream os(str);
        m.print(os);
        return os.str();
      });

  py::class_<mlir::FuncOp>(m, "FuncOp")
      .def("create",
           [](mlir::Location location, std::string name,
              mlir::FunctionType type) {
             auto func = mlir::FuncOp::create(location, name, type);
             func.addEntryBlock();
             return func;
           })
      .def(
          "getBody",
          [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
          py::return_value_policy::reference)
      .def("getArguments",
           [](mlir::FuncOp& f) { return f.getArguments().vec(); })
      .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
      .def("getType", &mlir::FuncOp::getType);

  py::class_<mlir::ReturnOp>(m, "ReturnOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Value> values) -> mlir::Operation* {
             return opb
                 .create<mlir::ReturnOp>(loc,
                                         mlir::ArrayRef<mlir::Value>(values))
                 .getOperation();
           });

  // mlir::TF::AddOp
  py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
           });

  py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
              mlir::Value reduction_indices,
              bool keep_dims = false) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
                                          reduction_indices, keep_dims)
                 .getOperation();
           });

  // mlir::TF::ConstOp
  py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Attribute value) -> mlir::Operation* {
             return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
           });

  // mlir::TF::EqualOp
  py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
                 .getOperation();
           });

  // mlir::TF::GreaterEqualOp
  py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
                 .getOperation();
           });

  // mlir::TF::GreaterOp
  py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LegacyCallOp
  py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              std::vector<mlir::Type> output, std::vector<mlir::Value> args,
              std::string f) -> mlir::Operation* {
             return opb
                 .create<mlir::TF::LegacyCallOp>(
                     loc, mlir::ArrayRef<mlir::Type>(output),
                     mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
                 .getOperation();
           });

  // mlir::TF::LessEqualOp
  py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
           });

  // mlir::TF::LessOp
  py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
           });

  // mlir::TF::NegOp
  py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc,
              mlir::Value x) -> mlir::Operation* {
             return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
           });

  py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
      .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
                        mlir::Value y) {
        return opb
            .create<mlir::TF::NotEqualOp>(
                loc, x, y, mlir::BoolAttr::get(true, opb.getContext()))
            .getOperation();
      });

  // mlir::TF::SubOp
  py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
      .def("create",
           [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
              mlir::Value y) -> mlir::Operation* {
             return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
           });
}
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0; full header as in attrs.cc above. */

#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

void init_types(py::module& m) {
  // Type
  py::class_<mlir::Type> Type(m, "Type");
  Type.def("getKind", &mlir::Type::getKind);

  // Type Enums
  py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
      .value("BF16", mlir::StandardTypes::BF16);

  // Type Sub-classes
  py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
      .def("getResults",
           [](mlir::FunctionType& ft) { return ft.getResults().vec(); });

  py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
      .def("get", &mlir::FloatType::get);

  py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
      .def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
                      &mlir::IntegerType::get));

  py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
      .def("get", &mlir::UnrankedTensorType::get);

  py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
      .def("get", [](std::vector<int64_t> shape, mlir::Type ty) {
        return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(shape), ty);
      });
}
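Together with `basic_classes.cc`, `builders.cc`, and `ops.cc`, these type bindings expose enough of the MLIR API to assemble IR from Python. A rough sketch of the intended flow, assuming the same hypothetical import path as in the FileCheck example above:

```python
from tensorflow.compiler.mlir.python.mlir_wrapper import mlir_wrapper as mw

mw.registerDialects()   # registers the TF, tf_executor, and std dialects
ctx = mw.MLIRContext()
b = mw.Builder(ctx)

# Build the function type (i32, i32) -> i32 from the bound type classes.
i32 = mw.IntegerType.get(32, ctx)
ftype = b.getFunctionType([i32, i32], [i32])

# FuncOp.create / OpBuilder / the Tf_*Op "create" wrappers follow the same
# pattern for constructing operations and inserting them into a ModuleOp.
```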
@@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
 ]
 tool_names = [
     'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
-    'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
-    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
-    'xla-opt'
+    'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
+    'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
+    'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
     'tensorflow/compiler/mlir',
     'tensorflow/compiler/mlir/lite',
     'tensorflow/compiler/mlir/tensorflow',
+    'tensorflow/compiler/mlir/tfjs',
     'tensorflow/compiler/mlir/xla',
     'tensorflow/compiler/aot',
     'tensorflow/compiler/xla/service/mlir_gpu',
@@ -36,7 +36,7 @@ filegroup(
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
     ],
 )

@@ -556,7 +556,7 @@ cc_library(
     deps = [
         ":tensorflow",
         "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LoopOpsTransforms",
+        "@llvm-project//mlir:SCFTransforms",
     ],
     alwayslink = 1,
 )
@@ -823,6 +823,7 @@ cc_library(
         ":mangling_util",
         ":tensorflow_attributes",
         ":tensorflow_types",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -1074,7 +1075,7 @@ genrule(
     srcs = [
         "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
         "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
         "@llvm-project//mlir:include/mlir/IR/OpBase.td",
         "ir/tf_generated_ops.td",
         "ir/tf_op_base.td",
@@ -192,6 +192,44 @@ retained with length 1.
   let verifier = [{ return Verify(*this); }];
 }

+def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
+  let summary = "An Op to exchange data across TPU replicas.";
+
+  let description = [{
+On each replica, the input is split into `split_count` blocks along
+`split_dimension` and sent to the other replicas given group_assignment. After
+receiving `split_count` - 1 blocks from other replicas, we concatenate the
+blocks along `concat_dimension` as the output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+    I32Tensor:$group_assignment,
+
+    I64Attr:$concat_dimension,
+    I64Attr:$split_dimension,
+    I64Attr:$split_count
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
   let summary = "Returns the argument of a complex number.";

@@ -1217,7 +1255,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

-def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
   let summary = "Clips tensor values to a specified min and max.";

   let description = [{
@@ -1408,6 +1446,30 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
   let hasCanonicalizer = 1;
 }

+def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> {
+  let summary = [{
+Shuffle dimensions of x according to a permutation and conjugate the result.
+  }];
+
+  let description = [{
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+  `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+  `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    TF_I32OrI64Tensor:$perm
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
   let summary = [{
 Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -1682,7 +1744,28 @@ Given an input tensor, this function computes hyperbolic cosine of every
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

-def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
+def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
+  let summary = "Compute the pairwise cross product.";
+
+  let description = [{
+`a` and `b` must be the same shape; they can either be simple 3-element vectors,
+or any shape where the innermost dimension is 3. In the latter case, each pair
+of corresponding 3-element vectors is cross-multiplied independently.
+  }];
+
+  let arguments = (ins
+    TF_IntOrFpTensor:$a,
+    TF_IntOrFpTensor:$b
+  );
+
+  let results = (outs
+    TF_IntOrFpTensor:$product
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
   let summary = "An Op to sum inputs across replicated TPU instances.";

   let description = [{
@@ -1706,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

-def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> {
+def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
   let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";

   let description = [{
@@ -3256,8 +3339,8 @@ Gather slices from `params` axis `axis` according to `indices`.

   let description = [{
 `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-Produces an output tensor with shape `params.shape[:axis] + indices.shape +
-params.shape[axis + 1:]` where:
+Produces an output tensor with shape `params.shape[:axis] +
+indices.shape[batch_dims:] + params.shape[axis + 1:]` where:

 ```python
     # Scalar indices (output is rank(params) - 1).
@@ -3542,6 +3625,31 @@ tf.imag(input) ==> [4.75, 5.75]
   TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
 }

+def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
+  let summary = [{
+Create a copy of `x` with the updated specified rows 'i' with values 'v'.
+  }];
+
+  let description = [{
+Creates a copy of tensor 'x' and updates the columns specified in tensor 'i'
+with the values 'v'. Originally this function was mutative however for
+compilation we make this operation create / operate on a copy.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    I32Tensor:$i,
+    TF_Tensor:$v
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes the reciprocal of x element-wise.";

@@ -4242,7 +4350,7 @@ cublas.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }

-def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> {
+def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> {
   let summary = [{
 Copy a tensor setting everything outside a central band in each innermost matrix to zero.
   }];
@@ -23,7 +23,7 @@ limitations under the License.
 #define TF_OP_BASE

 include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/SideEffects.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"

 //===----------------------------------------------------------------------===//
@@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
           "$_op.getOperand(" # opId # ").getType(), "
           "$_op.getResult(" # resId # ").getType())">]>;

+
+class TF_AllTypesMatchPred<list<string> values> :
+    CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
+
+class TF_AllTypesMatch<list<string> names> :
+    PredOpTrait<
+        "all of {" # StrJoin<names>.result # "} have dynamically equal types ",
+        TF_AllTypesMatchPred<
+            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
+
 //===----------------------------------------------------------------------===//
 // TensorFlow op definitions
 //===----------------------------------------------------------------------===//
@@ -110,47 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
   return !type || type.getRank() <= rank;
 }

-// Returns true if the given pair of TensorFlow types can be cast to one
-// another. In other words, a single run-time value is legal for both the types.
-// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
-static bool AreCastCompatible(Type a, Type b) {
-  if (TensorCastOp::areCastCompatible(a, b)) return true;
-
-  // Resource types may optionally contain subtypes information that does not
-  // match. Check subtypes compatibility when possible, otherwise treat them as
-  // compatible.
-  auto a_or_element_type = getElementTypeOrSelf(a);
-  auto b_or_element_type = getElementTypeOrSelf(b);
-
-  auto a_kind = a_or_element_type.getKind();
-  auto b_kind = b_or_element_type.getKind();
-
-  if (a_kind == TensorFlowTypes::RESOURCE &&
-      b_kind == TensorFlowTypes::RESOURCE) {
-    auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
-    auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
-    bool a_has_subtype = !a_resource_type.getSubtypes().empty();
-    bool b_has_subtype = !b_resource_type.getSubtypes().empty();
-
-    if (!a_has_subtype || !b_has_subtype) return true;
-
-    assert(a_resource_type.getSubtypes().size() <= 1 &&
-           "Resource type must have at most one subtype");
-    assert(b_resource_type.getSubtypes().size() <= 1 &&
-           "Resource type must have at most one subtype");
-
-    return TensorCastOp::areCastCompatible(
-        a_resource_type.getSubtypes().front(),
-        b_resource_type.getSubtypes().front());
-  }
-
-  // Variant types may optionally contain subtypes information that need not
-  // match. It is also not possible to compare subtypes for compatibility as
-  // their interpretation depends on the ops operating on them. So, accept all
-  // pairs of variant types.
-  return a_kind == TensorFlowTypes::VARIANT &&
-         b_kind == TensorFlowTypes::VARIANT;
-}
-
 static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
   return dim_or_rank == -1;
@@ -984,11 +943,10 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
 
 LogicalResult ConstOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
-    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  for (NamedAttribute named_attr : attributes) {
-    if (named_attr.first.strref() != "value") continue;
-    auto value = named_attr.second;
+  auto value = attributes.get("value");
+  if (!value) return emitOptionalError(location, "missing attribute 'value'");
   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
     inferredReturnTypes.assign({elem_attr.getType()});
     return success();
@@ -996,8 +954,6 @@ LogicalResult ConstOp::inferReturnTypes(
   return emitOptionalError(location,
                            "attribute 'value' failed to satisfy constraint: "
                            "constant vector/tensor");
-  }
-  return emitOptionalError(location, "missing attribute 'value'");
 }
 
 //===----------------------------------------------------------------------===//
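
Note on the ConstOp change above: with the `DictionaryAttr` form of `inferReturnTypes`, the "value" attribute is found by a lookup in the sorted attribute dictionary rather than a linear scan over `NamedAttribute` pairs. A minimal sketch of the same lookup pattern, assuming `op` is any `mlir::Operation*` of this MLIR vintage (illustrative only, not part of the diff):

mlir::DictionaryAttr dict = op->getAttrDictionary();
if (mlir::Attribute value = dict.get("value")) {
  if (auto elem_attr = value.dyn_cast<mlir::ElementsAttr>()) {
    mlir::Type inferred = elem_attr.getType();  // the constant's result type
    (void)inferred;
  }
}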
@@ -1416,7 +1372,7 @@ static LogicalResult Verify(DynamicStitchOp op) {
     auto expected_out_ty =
         RankedTensorType::get(expected_shape, out_ty.getElementType());
 
-    if (!AreCastCompatible(out_ty, expected_out_ty)) {
+    if (!AreCastCompatible({out_ty, expected_out_ty})) {
       return op.emitOpError() << "has invalid output type; should be "
                                  "compatible with inferred type "
                               << expected_out_ty;
@@ -1817,14 +1773,14 @@ static LogicalResult Verify(IfOp op) {
   for (unsigned i = 0; i < expectedNumInputs; ++i) {
     auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
     auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible(operandType, thenInputType))
+    if (!AreCastCompatible({operandType, thenInputType}))
       return op.emitError(
           llvm::formatv("then branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
                         thenInputType, operandType, i));
 
     auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible(operandType, elseInputType))
+    if (!AreCastCompatible({operandType, elseInputType}))
       return op.emitError(
           llvm::formatv("else branch input type {0} is incompatible with "
                         "operand type {1} at index {2}",
@@ -1832,7 +1788,7 @@ static LogicalResult Verify(IfOp op) {
 
     // If branches have incompatible input types that means that no tensor can
     // serve as input to both the functions. Hence, the op is invalid.
-    if (!AreCastCompatible(thenInputType, elseInputType))
+    if (!AreCastCompatible({thenInputType, elseInputType}))
       return op.emitError(llvm::formatv(
           "branches inputs have incompatible types {0} and {1} at index {2}",
           thenInputType, elseInputType, i));
@@ -1848,14 +1804,14 @@ static LogicalResult Verify(IfOp op) {
   for (unsigned i = 0; i < expectedNumResults; ++i) {
     auto resultType = op.getResult(i).getType().cast<TensorType>();
     auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible(thenResultType, resultType))
+    if (!AreCastCompatible({thenResultType, resultType}))
       return op.emitError(
           llvm::formatv("then branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
                         thenResultType, resultType, i));
 
     auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible(elseResultType, resultType))
+    if (!AreCastCompatible({elseResultType, resultType}))
       return op.emitError(
           llvm::formatv("else branch result type {0} is incompatible with op "
                         "result type {1} at index {2}",
@@ -3792,7 +3748,7 @@ static LogicalResult Verify(WhileOp op) {
       auto aType = a.second[idx];
       auto bType = b.second[idx];
 
-      if (!AreCastCompatible(aType, bType))
+      if (!AreCastCompatible({aType, bType}))
         return op.emitError(llvm::formatv(
             "{0} type {1} is incompatible with {2} type {3} at index {4}",
             a.first, aType, b.first, bType, idx));
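
Note: the verifier call sites above change because `AreCastCompatible` now takes a single `ArrayRef<Type>` instead of two `Type` parameters. The braced list binds through `std::initializer_list`, materializing a temporary array that the `ArrayRef` views only for the duration of the call, which is all a synchronous check needs. It also allows more than two types to be checked jointly in one call; a hypothetical three-way check (not a call site in this change):

bool ok = AreCastCompatible({operandType, thenInputType, elseInputType});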
@@ -28,6 +28,134 @@ llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
   if (shaped_type.hasRank()) return shaped_type.getShape();
   return llvm::None;
 }
 
+// Merges cast compatible shapes and returns a more refined shape. The two
+// shapes are cast compatible if they have the same rank and at each dimension,
+// either both have same size or one of them is dynamic. Returns false if the
+// given shapes are not cast compatible. The refined shape is same or more
+// precise than the two input shapes.
+bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
+                            llvm::ArrayRef<int64_t> b_shape,
+                            llvm::SmallVectorImpl<int64_t>* refined_shape) {
+  if (a_shape.size() != b_shape.size()) return false;
+  int64_t rank = a_shape.size();
+  refined_shape->reserve(rank);
+  for (auto dims : llvm::zip(a_shape, b_shape)) {
+    int64_t dim1 = std::get<0>(dims);
+    int64_t dim2 = std::get<1>(dims);
+
+    if (mlir::ShapedType::isDynamic(dim1)) {
+      refined_shape->push_back(dim2);
+      continue;
+    }
+    if (mlir::ShapedType::isDynamic(dim2)) {
+      refined_shape->push_back(dim1);
+      continue;
+    }
+    if (dim1 == dim2) {
+      refined_shape->push_back(dim1);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
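The per-dimension merge rule above is small enough to state standalone. A self-contained sketch of the same logic, using -1 as the dynamic-size sentinel the way MLIR's ShapedType does (toy code, independent of the change):

#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for MLIR's dynamic dimension size

// Returns true and fills `out` if shapes of equal rank agree wherever both
// dimensions are static; a static size always wins over a dynamic one.
bool MergeShapes(const std::vector<int64_t>& a, const std::vector<int64_t>& b,
                 std::vector<int64_t>* out) {
  if (a.size() != b.size()) return false;
  out->clear();
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == kDynamic) out->push_back(b[i]);
    else if (b[i] == kDynamic || a[i] == b[i]) out->push_back(a[i]);
    else return false;  // two different static sizes can never be merged
  }
  return true;
}

// Example: merging {2, kDynamic, kDynamic} with {kDynamic, 4, kDynamic}
// yields {2, 4, kDynamic}; merging {2, 3} with {2, 4} fails.
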
+// Given two types `a` and `b`, returns a refined type which is cast compatible
+// with both `a` and `b` and is equal to or more precise than both of them. It
+// returns empty Type if the input types are not cast compatible.
+//
+// The two types are considered cast compatible if they have dynamically equal
+// shapes and element type. For element types that do not have subtypes, they
+// must be equal. However for TensorFlow types such as Resource and Variant,
+// that also have subtypes, we recursively check for subtype compatibility for
+// Resource types and assume all variant types are cast compatible. If either
+// one of `a` or `b` have empty subtypes, they are considered cast compatible.
+//
+// The returned type is same or more precise than the input types. For example,
+// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
+// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
+//
+// Provides option to ignore ref types on 'a'. This is useful for TF ops that
+// might allow operands to either be same as result type or be a ref type
+// corresponding to it.
+mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
+                                 bool may_ignore_ref_type_a) {
+  // Fast path if everything is equal.
+  if (a == b) return b;
+
+  auto a_tt = a.dyn_cast<mlir::TensorType>();
+  auto b_tt = b.dyn_cast<mlir::TensorType>();
+
+  // If only one of a or b is a tensor type, they are incompatible.
+  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
+
+  // For non-tensor types, we do not need to worry about shape and can return
+  // early.
+  if (!a_tt && !b_tt) {
+    // Remove ref types.
+    if (may_ignore_ref_type_a) {
+      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
+        a = ref_type.RemoveRef();
+        if (a == b) return a;
+      }
+    }
+    if (a.getKind() != b.getKind()) return nullptr;
+
+    // If either is not a type that contains subtypes then the types are not
+    // cast compatible.
+    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    if (!a_wst || !b_wst) return nullptr;
+
+    // For Variant types we are more permissive right now and accept all pairs
+    // of Variant types. If we are more constrained and check compatibility of
+    // subtypes, we might reject valid graphs.
+    // TODO(prakalps): Variant doesn't have a subtype, we assign it
+    // one, so we should only assign it one when we know the subtype. Then we
+    // can be more constrained and check subtypes for cast compatibility as
+    // well.
+    if (a.isa<mlir::TF::VariantType>()) return a;
+
+    // For Resource types, we recursively check the subtypes for cast
+    // compatibility, if possible. Otherwise treat them as compatible.
+    auto a_wst_st = a_wst.GetSubtypes();
+    auto b_wst_st = b_wst.GetSubtypes();
+    if (a_wst_st.empty() || b_wst_st.empty()) return a;
+    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
+    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
+    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
+      mlir::Type refined_st =
+          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
+                                /*may_ignore_ref_type_a=*/false);
+      if (!refined_st) return nullptr;
+      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
+    }
+
+    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
+  }
+
+  // For tensor types, check compatibility of both element type and shape.
+  mlir::Type refined_element_ty = GetCastCompatibleType(
+      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
+  if (!refined_element_ty) return nullptr;
+
+  if (!a_tt.hasRank() && !b_tt.hasRank()) {
+    return mlir::UnrankedTensorType::get(refined_element_ty);
+  }
+  if (!a_tt.hasRank()) {
+    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
+  }
+  if (!b_tt.hasRank()) {
+    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
+  }
+
+  llvm::SmallVector<int64_t, 8> refined_shape;
+  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
+    return nullptr;
+
+  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
+}
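
The documented example can be checked through the public `AreCastCompatible` wrapper defined further below, since the helper itself sits in an anonymous namespace. A minimal sketch, assuming an MLIR build of this era, builtin tensor types only, and the include paths named here:

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

void Example() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::FloatType::getF32(&ctx);
  int64_t d = -1;  // dynamic dimension
  mlir::Type a = mlir::RankedTensorType::get({2, d, d}, f32);  // tensor<2x?x?xf32>
  mlir::Type b = mlir::RankedTensorType::get({d, 4, d}, f32);  // tensor<?x4x?xf32>
  // Compatible: the refined type computed internally is tensor<2x4x?xf32>.
  bool compatible = mlir::TF::AreCastCompatible({a, b});
  (void)compatible;  // true
}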
 }  // namespace
 
 namespace mlir {
@@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
 
 bool HasCompatibleElementTypes(Type lhs, Type rhs,
                                bool may_ignore_ref_type_lhs) {
-  // Fast path if everything is equal.
-  if (lhs == rhs) return true;
-
-  // In TF all values are tensors.
-  auto lhs_tt = lhs.cast<TensorType>();
-  auto rhs_tt = rhs.cast<TensorType>();
-
-  // Verify matching element types. These should be identical dynamically,
-  // so this allows for types not yet fully refined.
-  auto lhs_et = lhs_tt.getElementType();
-  auto rhs_et = rhs_tt.getElementType();
-  if (lhs_et == rhs_et) return true;
-
-  // Remove ref types.
-  if (may_ignore_ref_type_lhs) {
-    if (auto ref_type = lhs_et.dyn_cast<TF::TensorFlowRefType>()) {
-      lhs_et = ref_type.RemoveRef();
-      if (lhs_et == rhs_et) return true;
-    }
-  }
-
-  if (lhs_et.getKind() != rhs_et.getKind()) return false;
-
-  // If either is not type that contain subtypes then the element types don't
-  // match.
-  auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
-  auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
-  if (!lhs_wst || !rhs_wst) return false;
-
-  // Consider the subtype recursively.
-  auto lhs_wst_st = lhs_wst.GetSubtypes();
-  auto rhs_wst_st = rhs_wst.GetSubtypes();
-  if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true;
-  if (lhs_wst_st.size() != rhs_wst_st.size()) return false;
-  for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
-    if (!HasCompatibleElementTypes(std::get<0>(subtypes),
-                                   std::get<1>(subtypes)))
-      return false;
-  }
+  return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
+}
+
+bool AreCastCompatible(ArrayRef<Type> types) {
+  Type common = types.front();
+  for (auto type : types.drop_front()) {
+    Type refined_type =
+        GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
+    if (!refined_type) return false;
+    common = refined_type;
+  }
   return true;
 }
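
The fold against a progressively refined `common` type is what makes the variadic form meaningful: it asks whether one runtime value could carry every type in the list, which is stronger than checking adjacent pairs. Continuing the earlier sketch's assumptions (`ctx`, `f32`, `d` as before; illustrative only):

mlir::Type t1 = mlir::RankedTensorType::get({2, d}, f32);  // tensor<2x?xf32>
mlir::Type t2 = mlir::RankedTensorType::get({d, 3}, f32);  // tensor<?x3xf32>
mlir::Type t3 = mlir::RankedTensorType::get({5, d}, f32);  // tensor<5x?xf32>
// Adjacent pairs (t1,t2) and (t2,t3) are each compatible, but no single value
// has all three types: the common type refines to tensor<2x3xf32> after t2,
// which conflicts with t3's leading dimension 5. The fold returns false.
bool joint = mlir::TF::AreCastCompatible({t1, t2, t3});  // false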
@@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
 bool HasCompatibleElementTypes(Type lhs, Type rhs,
                                bool may_ignore_ref_type_lhs = false);
 
+// Returns true if all TensorFlow types can be cast to one another. In other
+// words, a single run-time value is legal for all of the types. For example,
+// tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast compatible.
+bool AreCastCompatible(ArrayRef<Type> types);
+
 }  // end namespace TF
 }  // end namespace mlir
@@ -881,20 +881,29 @@ func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64
 
 // -----
 
-// Test invalid tf.MatrixBandPart
-func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
-  // expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
-  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
-  return %0 : tensor<64x64xbf16>
+// Test valid tf.MatrixBandPart
+// CHECK-LABEL: func @testValidMatrixBandPartOpUnrankedBand
+func @testValidMatrixBandPartOpUnrankedBand(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
+  return %0 : tensor<*xbf16>
+}
+
+// -----
+
+// Test valid tf.MatrixBandPart
+// CHECK-LABEL: func @testValidMatrixBandPartOpCompatibleDynamicShapes
+func @testValidMatrixBandPartOpCompatibleDynamicShapes(%arg0: tensor<?x10x?xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<?x?x8xbf16> {
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<?x10x?xbf16>, tensor<i64>, tensor<i64>) -> tensor<?x?x8xbf16>
+  return %0 : tensor<?x?x8xbf16>
 }
 
 // -----
 
 // Test invalid tf.MatrixBandPart
-func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
-  // expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
-  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
-  return %0 : tensor<*xbf16>
+func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
+  // expected-error @+1 {{op failed to verify that all of {input, band} have dynamically equal types}}
+  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
+  return %0 : tensor<64x64xbf16>
}
 
 // -----
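
The verifier message change ("same type" to "dynamically equal types") is exactly what the new compatible-dynamic-shapes test exercises: the input and result only need to agree wherever both are static. Under the same assumptions as the sketches above (and with bf16 as a builtin float type), the accepted pair can be checked directly:

auto bf16 = mlir::FloatType::getBF16(&ctx);
mlir::Type in  = mlir::RankedTensorType::get({d, 10, d}, bf16);  // tensor<?x10x?xbf16>
mlir::Type out = mlir::RankedTensorType::get({d, d, 8}, bf16);   // tensor<?x?x8xbf16>
bool ok = mlir::TF::AreCastCompatible({in, out});  // true: refines to tensor<?x10x8xbf16>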
@@ -1,13 +1,17 @@
 // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure
 
-// Tests extraction of a single outside compiled cluster with no input or output dependecies.
+// Tests extraction of outside compiled ops at head of TPU computation.
 
-// CHECK-LABEL: func @nodep_single_head_outside_compilation
-func @nodep_single_head_outside_compilation() -> () {
-  // CHECK: "tf.A"
-  // CHECK-NEXT: "tf_device.launch"
-  "tf_device.launch"() ( {
-    "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
+func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+  // CHECK: tf_device.launch
+  // CHECK: "tf.A"
+  // CHECK-NEXT: tf_device.return
+  //
+  // CHECK: "tf_device.cluster"
+  // CHECK: "tf.C"
+  // CHECK-NEXT: tf_device.return
+  "tf_device.cluster"() ( {
+    "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
     "tf.B"() : () -> ()
     "tf.C"() : () -> ()
     tf_device.return
@@ -15,15 +19,62 @@ func @nodep_single_head_outside_compilation() -> () {
   return
 }
 
-// CHECK-LABEL: func @nodep_multiple_head_outside_compilation
-func @nodep_multiple_head_outside_compilation() -> () {
-  // CHECK: "tf.A"
-  // CHECK-NEXT: "tf.B"
-  // CHECK-NEXT: "tf_device.launch"
-  "tf_device.launch"() ( {
-    "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
-    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
-    "tf.C"() : () -> ()
+// CHECK-LABEL: func @multiple_head_outside_compilation
+func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+  // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+  // CHECK: %[[A_OUT:.*]] = "tf.A"
+  // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
+  // CHECK: "tf.C"
+  // CHECK-NEXT: tf_device.return %[[B_OUT]]
+  //
+  // CHECK: "tf_device.cluster"
+  // CHECK: "tf.D"(%[[LAUNCH_OUT]])
+  // CHECK-NEXT: tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
+    "tf.D"(%1) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
+func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
+  // CHECK-NOT: tf_device.launch
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: "tf.A"
+  // CHECK-NEXT: "tf.B"
+  // CHECK-NEXT: "tf.C"
+  // CHECK-NEXT: tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+    "tf.C"(%1) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
+func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
+  // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+  // CHECK: %[[A_OUT:.*]] = "tf.A"
+  // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
+  // CHECK-NEXT: tf_device.return %[[D_OUT]]
+  //
+  // CHECK: "tf_device.cluster"
+  // CHECK: "tf.B"
+  // CHECK: "tf.C"
+  // CHECK: "tf.E"
+  // CHECK-NEXT: tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"() {} : () -> (tensor<i32>)
+    %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
+    %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+    %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
     tf_device.return
   }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
   return
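
The three new tests pin down the hoisting criterion: an op moves into the head launch only if it is marked `_xla_outside_compilation` and none of its operands are produced by an op that stays in the cluster. A toy model of that predicate, reconstructed from the tests alone (hypothetical names; not the pass's actual code):

#include <string>
#include <unordered_set>
#include <vector>

struct ToyOp {
  std::string name;
  bool outside_compiled;           // carries _xla_outside_compilation
  std::vector<int> operand_defs;   // producer index in cluster body; -1 = external
};

// Returns indices of ops that would be hoisted into the head launch.
std::unordered_set<int> HeadOps(const std::vector<ToyOp>& body) {
  std::unordered_set<int> hoisted;
  for (int i = 0; i < static_cast<int>(body.size()); ++i) {
    if (!body[i].outside_compiled) continue;
    bool deps_ok = true;
    for (int def : body[i].operand_defs)
      if (def >= 0 && !hoisted.count(def)) deps_ok = false;  // depends on a TPU op
    if (deps_ok) hoisted.insert(i);
  }
  return hoisted;
}

// Mirrors @test_ops_with_tpu_operands_not_extracted: A (external operand) and
// D (operand A, already hoisted) move to the launch; C stays because it
// consumes the TPU op B.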
@@ -3,12 +3,12 @@
 // Tests that missing `_xla_outside_compilation` attribute value results in an error.
 
 func @missing_outside_compilation_attribute() -> () {
-  "tf_device.launch"() ( {
+  "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
     "tf.B"() {_xla_outside_compilation = ""} : () -> ()
     tf_device.return
-  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  }) {cluster_attr = "cluster_attr"} : () -> ()
   return
 }
 
@@ -18,11 +18,11 @@ func @missing_outside_compilation_attribute() -> () {
 // CHECK-LABEL: func @no_outside_compilation
 func @no_outside_compilation() -> tensor<?xi32> {
-  %0 = "tf_device.launch"() ( {
+  %0 = "tf_device.cluster"() ( {
     %1 = "tf.A"() : () -> tensor<?xi32>
     %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
-  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
 
@@ -36,16 +36,15 @@ func @nodep_single_outside_compilation() -> () {
   // CHECK-NEXT: "tf_device.launch"
   // CHECK-NEXT: "tf.B"
   // CHECK-NOT: _xla_outside_compilation
-  // CHECK: "tf_device.launch"
+  // CHECK: "tf_device.cluster"
   // CHECK-NEXT: "tf.A"
-  // CHECK: device = "tpu0"
-  // CHECK-SAME: launch_attr = "launch_attr"
-  "tf_device.launch"() ( {
+  // CHECK: cluster_attr = "cluster_attr"
+  "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.C"() : () -> ()
     tf_device.return
-  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  }) {cluster_attr = "cluster_attr"} : () -> ()
   return
 }
 
@@ -59,19 +58,18 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
   // CHECK-NEXT: "tf.C"
   // CHECK-NEXT: "tf.D"
   // CHECK-NOT: _xla_outside_compilation
-  // CHECK: "tf_device.launch"
+  // CHECK: "tf_device.cluster"
   // CHECK-NEXT: "tf.A"
   // CHECK-NEXT: "tf.E"
-  // CHECK: device = "tpu0"
-  // CHECK-SAME: launch_attr = "launch_attr"
-  "tf_device.launch"() ( {
+  // CHECK: cluster_attr = "cluster_attr"
+  "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.E"() : () -> ()
     tf_device.return
-  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  }) {cluster_attr = "cluster_attr"} : () -> ()
   return
 }
 
@@ -80,15 +78,16 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
 // CHECK-LABEL: func @nodep_multiple_outside_compilation
 func @nodep_multiple_outside_compilation() -> () {
   // CHECK: "tf_device.parallel_execute"
-  // CHECK-COUNT-3: "tf_device.launch"
-  "tf_device.launch"() ( {
+  // CHECK-COUNT-2: "tf_device.launch"
+  // CHECK: "tf_device.cluster"
+  "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.C"() : () -> ()
     "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
     "tf.E"() : () -> ()
     tf_device.return
-  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  }) {cluster_attr = "cluster_attr"} : () -> ()
   return
 }
 
@@ -100,17 +99,17 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
   // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
   // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
   // CHECK-NEXT: "tf_device.launch"
-  // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
+  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
   // CHECK: tf_device.return
-  // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
+  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
   // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
   %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
-    %2 = "tf_device.launch"() ( {
+    %2 = "tf_device.cluster"() ( {
       "tf.A"() : () -> ()
       "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
       %3 = "tf.C"() : () -> tensor<?xi32>
       tf_device.return %3 : tensor<?xi32>
-    }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
+    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
 
@@ -125,17 +124,17 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
   // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
   // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
   // CHECK-NEXT: "tf_device.launch"
-  // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
   // CHECK: tf_device.return
-  // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
+  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
   // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
   %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
-    %2, %3 = "tf_device.launch"() ( {
+    %2, %3 = "tf_device.cluster"() ( {
      %4 = "tf.A"() : () -> tensor<?xf32>
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
-    }) {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
+    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
     tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
   }
 
@@ -1222,6 +1222,41 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
 // -----
 
+// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
+  // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
+  func @replicated_parallel_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      // CHECK: "tf._TPUCompileMlir"
+      // CHECK: "tf.TPUCompileSucceededAssert"
+      // CHECK: "tf_device.parallel_execute"
+      // CHECK: "tf.TPUExecute"
+      %3 = "tf_device.parallel_execute"() ( {
+        "tf.D"() : () -> ()
+        tf_device.return
+      }, {
+        %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
+
+        tf_device.return %4 : tensor<?xi32>
+      }) : () -> (tensor<?xi32>)
+      tf_device.return %3 : tensor<?xi32>
+    }
+    %2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
+    return %2 : tensor<?xi32>
+  }
+
+  func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
 // Tests devices are set properly for non replicated model parallelism.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
@@ -258,7 +258,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
 
 // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
 // at head/tail of TPU cluster to run before/after TPU computation.
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTPUExtractHeadTailOutsideCompilationPass();
 
 // Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
@@ -66,8 +66,7 @@ using tensorflow::shape_inference::ShapeHandle;
 namespace mlir {
 namespace TF {
 namespace {
-Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
-    FuncOp func) {
+Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
   // Find any return ops.
   SmallVector<ReturnOp, 4> return_ops;
   for (Block& block : func) {
@@ -137,9 +136,9 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
       cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result,
                                      /*truncate=*/b.getBoolAttr(false));
     }
-    return mlir::Value(cast_op);
+    return Value(cast_op);
   };
-  for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
+  for (OpOperand& use : make_early_inc_range(result.getUses())) {
     if (use.getOwner()->getDialect() != tf_dialect &&
         !IsSupportedNonTFOp(use.getOwner()))
       use.set(get_cast_op());
@@ -162,7 +161,7 @@ Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
 bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
                                  Operation* op, Dialect* tf_dialect) {
   bool changed = false;
-  for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
+  for (auto entry : zip(pass_through_operands, op->getResults())) {
     Type operand_type = std::get<0>(entry).getType();
     Value result = std::get<1>(entry);
     if (result.getType() == operand_type) continue;
@@ -204,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) {
                                        tf_dialect);
   }
   // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
-  if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
+  if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) {
     return InferShapeForPassThroughOps(
         tensor_cast.getOperation()->getOperands(), op, tf_dialect);
   }
@@ -254,7 +253,7 @@ GetSubtypes(Type type) {
 // match the i-th operand type). Returns true if anything is changed.
 bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
   bool changed = false;
-  for (auto entry : llvm::zip(operands, results)) {
+  for (auto entry : zip(operands, results)) {
     Type operand_type = std::get<0>(entry).getType();
     Type result_type = std::get<1>(entry).getType();
     if (operand_type == result_type) continue;
@@ -291,14 +290,13 @@ bool InferShapeForCall(Operation* op) {
   CallInterfaceCallable callable = call_op.getCallableForCallee();
   SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>();
   if (!sym) return false;
-  FuncOp func =
-      dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
+  FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
   if (!func) return false;
 
   bool changed = false;
   // Map each of the results of the call to the returned type of the
   // function.
-  for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) {
+  for (auto result : zip(op->getResults(), func.getType().getResults())) {
     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
     // Skip already statically shaped results.
     if (!CanBeRefined(std::get<0>(result).getType())) continue;
@@ -323,8 +321,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
   Operation* op = infer_ti.getOperation();
   SmallVector<Type, 4> inferred;
   LogicalResult res = infer_ti.inferReturnTypes(
-      op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(),
-      op->getRegions(), inferred);
+      op->getContext(), op->getLoc(), op->getOperands(),
+      op->getAttrDictionary(), op->getRegions(), inferred);
   if (failed(res)) {
     op->emitOpError("failed to refine type as inference failed");
     return false;
@@ -335,7 +333,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
   // Map each of the results of the call to the returned type of the
   // function.
   bool changed = false;
-  for (auto result : llvm::zip(op->getResults(), inferred)) {
+  for (auto result : zip(op->getResults(), inferred)) {
     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
 
     // Inserts a cast back to the original type if any user is not in the
@@ -356,7 +354,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
 // scalar value).
 struct ValuePort {
-  llvm::PointerUnion<Operation*, BlockArgument> producer;
+  PointerUnion<Operation*, BlockArgument> producer;
   SmallVector<unsigned int, 2> port;
 
   bool operator==(const ValuePort& other) const {
@@ -374,37 +372,36 @@ struct ValuePort {
       port = {0};
     }
   }
-  ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
+  ValuePort(PointerUnion<Operation*, BlockArgument> producer,
             SmallVector<unsigned int, 2> port)
       : producer(producer), port(port) {}
 
-  llvm::raw_ostream& print(llvm::raw_ostream& os) const {
+  raw_ostream& print(raw_ostream& os) const {
     if (auto* op = producer.dyn_cast<Operation*>())
       os << "op " << op->getName();
     if (auto ba = producer.dyn_cast<BlockArgument>())
       os << "block_arg " << ba.getArgNumber();
-    os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
+    os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
     return os;
   }
 };
 
 struct ValuePortHasher {
   std::size_t operator()(const ValuePort& other) const {
-    return llvm::hash_combine(
-        llvm::hash_value(other.producer.getOpaqueValue()),
-        llvm::hash_value(ArrayRef<unsigned int>(other.port)));
+    return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
+                        hash_value(ArrayRef<unsigned int>(other.port)));
   }
 };
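
For context on the hasher: `llvm::hash_value` has overloads for pointers and for whole `ArrayRef`s, and `llvm::hash_combine` folds the resulting codes together, so both the producer's identity and every port index feed the final hash. An equivalent standalone form (illustrative only):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"

std::size_t HashProducerAndPort(void* producer,
                                llvm::ArrayRef<unsigned int> port) {
  return llvm::hash_combine(llvm::hash_value(producer),
                            llvm::hash_value(port));
}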
 
 using ValuePortResultMap =
     std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
-using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
-using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
-using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
+using ComputedQueryFn = function_ref<bool(ValuePort)>;
+using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
+using ValuePortInputs = SmallVectorImpl<ValuePort>;
 
-// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
+// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
 // intended to be switched to op interfaces once more refined.
-LogicalResult InputsRequiredForOutput(ValuePort value_port,
-                                      ComputedQueryFn has_been_computed,
-                                      ValuePortInputs* inputs) {
+LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+                                             ComputedQueryFn has_been_computed,
+                                             ValuePortInputs* inputs) {
   auto op = value_port.producer.dyn_cast<Operation*>();
@@ -460,26 +457,94 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
   return nullptr;
 }
 
-ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
+// Context used during ShapeInference. This class contains common information
+// that is required by the individual shape inference helper functions (e.g.,
+// TF Graph version, constant values computed, etc.)
+class ShapeInference {
+ public:
+  ShapeInference(int64_t graph_version, MLIRContext* context);
+
+  LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+                                               ValuePortInputs* inputs) {
+    return ::mlir::TF::ComputeInputsRequiredForOutput(
+        value_port,
+        [this](const ValuePort& port) {
+          return results_.find(port) != results_.end();
+        },
+        inputs);
+  }
+
+  Attribute ComputeOutputComponent(const ValuePort& value_port) {
+    return ::mlir::TF::ComputeOutputComponent(
+        value_port, [this](const ValuePort& port) { return results_[port]; });
+  }
+
+  // Returns ShapeHandle if the op result could be computed as shape.
+  ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
+
+  void RecordValue(const ValuePort& value_port, Attribute value) {
+    results_[value_port] = value;
+  }
+
+  // Performs shape inference on the provided op and returns true if the type
+  // of at least one result has been changed.
+  // A tf.Cast() is inserted for any uses that aren't in the TensorFlow
+  // dialect.
+  // `graph_version` indicates the current GraphDef compatibility versions
+  // (the versions field in graph.proto).
+  bool InferShapeForSingleOperation(Operation* op);
+
+  // Infers shape on the provided region, including nested ones, iterating
+  // until fixed point with a limit of max_iteration. Returns success if a
+  // fixed point is reached before max_iteration.
+  LogicalResult InferShapeUntilFixPoint(Region* region,
+                                        int64_t max_iteration = 10);
+
+  // Updates input types and refines shapes inside bodies of functions that
+  // are attached to ControlFlow ops (If/While). These functions include
+  // Then/Else branches of IfOp and Cond/Body functions of WhileOp. These
+  // functions share the following common properties:
+  //   1) They are never reused, i.e. they have a single use in the module.
+  //   2) Their input types match those of their parent ops (excluding inputs
+  //      like predicate).
+  // Returns a boolean indicating whether any change has been applied.
+  LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
+                                              ArrayRef<Type> input_types,
+                                              int64_t max_iteration);
+
+  // Propagates the shapes to the functions named.
+  LogicalResult PropagateShapeToFunctions(
+      ModuleOp module, Operation::operand_type_range input_types,
+      ArrayRef<StringRef> func_names, int64_t max_iteration);
+
+  // Shape propagation for call/control flow ops.
+  LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
+                                                    int64_t max_iteration);
+
+ private:
+  // Mapping between ValuePort (which corresponds to an OpResult or smaller,
+  // e.g., first element of OpResult produced) to an Attribute if the
+  // ValuePort corresponds to a constant value.
+  ValuePortResultMap results_;
+  int64_t graph_version_;
+  MLIRContext* context_;
+  Dialect* tf_dialect_;
+};
+
+ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
+    : graph_version_(graph_version) {
+  context_ = context;
+  tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
+}
+
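
A sketch of how the new class is presumably driven: construct it once per module, then iterate the functions. This driver is hypothetical (the actual pass wiring is outside this hunk), and it assumes `ModuleOp::getOps<FuncOp>()` and `FuncOp::getBody()` from MLIR of this era:

LogicalResult InferModuleShapes(ModuleOp module, int64_t producer_version) {
  ShapeInference context(producer_version, module.getContext());
  for (auto func : module.getOps<FuncOp>()) {
    if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
      return failure();
  }
  return success();
}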
+ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
+                                                 InferenceContext* ic) {
   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
   auto rt = result.getType().dyn_cast<RankedTensorType>();
   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
   int dim_size = rt.getDimSize(0);
 
   // Worklist to direct partial evaluation.
-  llvm::SmallVector<ValuePort, 4> worklist;
-  // The ValuePort evaluated results.
-  // TODO(jpienaar): This could be cached across invocations (e.g., part of some
-  // inference context).
-  ValuePortResultMap evaluated;
-  // Returns whether a ValuePort has been previously computed.
-  auto has_been_computed = [&evaluated](const ValuePort& port) {
-    return evaluated.find(port) != evaluated.end();
-  };
-  // Returns previously computed ValuePort value.
-  auto values = [&evaluated](const ValuePort& port) -> Attribute {
-    return evaluated[port];
-  };
+  SmallVector<ValuePort, 4> worklist;
 
   // Simple evaluator that attempts to partially evaluate the input value even
   // if unable to evaluate the complete output. Below follows a simple stack
@@ -498,7 +563,7 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
     LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
 
     SmallVector<ValuePort, 4> inputs;
-    auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
+    auto res = ComputeInputsRequiredForOutput(front, &inputs);
     if (failed(res)) {
       // Abort if unable to find which required inputs need to be computed.
       worklist.clear();
@@ -513,16 +578,16 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
       continue;
     }
 
-    auto ret = ComputeOutputComponent(front, values);
+    auto ret = ComputeOutputComponent(front);
     if (!ret) continue;
 
-    evaluated[front] = ret;
+    RecordValue(front, ret);
     LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
 
     // If worklist is empty, then this is the root query op.
     if (worklist.empty()) {
      LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
-      if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
+      if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
        if (dea.getNumElements() != 1) {
          LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
          return {};
@@ -536,9 +601,8 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
       return ic->MakeShape(dims);
 }
 
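
The worklist in `ComputeOutputAsShape` follows a classic demand-driven pattern that the class refactor makes stateful: ask which inputs a port needs, queue the missing ones, and only compute and record once everything required is available. A toy version of the control flow, assuming an acyclic dependency graph (illustrative only):

#include <functional>
#include <map>
#include <vector>

using Node = int;
using Inputs = std::function<std::vector<Node>(Node)>;
using Compute = std::function<int(Node, const std::map<Node, int>&)>;

bool EvaluateDemandDriven(Node root, Inputs inputs, Compute compute,
                          std::map<Node, int>* results) {
  std::vector<Node> worklist{root};
  while (!worklist.empty()) {
    Node n = worklist.back();
    std::vector<Node> missing;
    for (Node in : inputs(n))
      if (!results->count(in)) missing.push_back(in);
    if (!missing.empty()) {
      // Leave n on the stack and compute its inputs first
      // (cf. ComputeInputsRequiredForOutput).
      worklist.insert(worklist.end(), missing.begin(), missing.end());
      continue;
    }
    worklist.pop_back();
    (*results)[n] = compute(n, *results);  // cf. RecordValue
  }
  return results->count(root) > 0;
}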
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||||
int64_t graph_version) {
|
assert(tf_dialect_ == op->getDialect());
|
||||||
assert(tf_dialect == op->getDialect());
|
|
||||||
// The shape function of these ops sometimes does not propagate subtypes
|
// The shape function of these ops sometimes does not propagate subtypes
|
||||||
// (handle shapes) for resource and variant types. We use a simple passthrough
|
// (handle shapes) for resource and variant types. We use a simple passthrough
|
||||||
// to make sure they are preserved in the output.
|
// to make sure they are preserved in the output.
|
||||||
@ -550,7 +614,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
// If no result for this op needs shape inference, we have a fast-path return.
|
// If no result for this op needs shape inference, we have a fast-path return.
|
||||||
// But if the type is a resource/variant, we do not skip it because we might
|
// But if the type is a resource/variant, we do not skip it because we might
|
||||||
// not have the handle shapes.
|
// not have the handle shapes.
|
||||||
if (llvm::none_of(op->getResultTypes(), CanBeRefined)) {
|
if (none_of(op->getResultTypes(), CanBeRefined)) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
|
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
|
||||||
<< op->getName() << "'.\n");
|
<< op->getName() << "'.\n");
|
||||||
return false;
|
return false;
|
||||||
@ -565,8 +629,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
|
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
|
||||||
// the end of this function.
|
// the end of this function.
|
||||||
if (isa<CastOp>(op) &&
|
if (isa<CastOp>(op) &&
|
||||||
llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
|
all_of(op->getResult(0).getUsers(), [&](Operation* user) {
|
||||||
return user->getDialect() != tf_dialect;
|
return user->getDialect() != tf_dialect_;
|
||||||
})) {
|
})) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
|
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
|
||||||
"dialect operation users '"
|
"dialect operation users '"
|
||||||
@ -646,7 +710,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
// Perform the shape inference using an InferenceContext with the input
|
// Perform the shape inference using an InferenceContext with the input
|
||||||
// shapes. This object is abstracting the information that the ShapeInference
|
// shapes. This object is abstracting the information that the ShapeInference
|
||||||
// function operates on.
|
// function operates on.
|
||||||
InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
|
InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
|
||||||
input_shapes, input_tensors,
|
input_shapes, input_tensors,
|
||||||
/*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
|
/*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
|
||||||
auto status = c.Run(op_reg_data->shape_inference_fn);
|
auto status = c.Run(op_reg_data->shape_inference_fn);
|
||||||
@ -659,7 +723,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
// Determine if, during shape computation, the shape functions attempted to
|
// Determine if, during shape computation, the shape functions attempted to
|
||||||
// query an input operand as shape where the input was not known/constant.
|
// query an input operand as shape where the input was not known/constant.
|
||||||
bool requires_inputs =
|
bool requires_inputs =
|
||||||
llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
|
any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
|
||||||
return c.requested_input_tensor_as_partial_shape(input) &&
|
return c.requested_input_tensor_as_partial_shape(input) &&
|
||||||
!input_tensors[input];
|
!input_tensors[input];
|
||||||
});
|
});
|
||||||
@ -723,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
new_element_type.isa<TF::VariantType>()) {
|
new_element_type.isa<TF::VariantType>()) {
|
||||||
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
|
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
|
||||||
if (handle_shapes_types) {
|
if (handle_shapes_types) {
|
||||||
llvm::SmallVector<mlir::TensorType, 1> subtypes;
|
SmallVector<TensorType, 1> subtypes;
|
||||||
OpBuilder b(op);
|
OpBuilder b(op);
|
||||||
for (const auto& shape_n_type : *handle_shapes_types) {
|
for (const auto& shape_n_type : *handle_shapes_types) {
|
||||||
Type element_type;
|
Type element_type;
|
||||||
@ -743,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
if (result.getType() == new_type) continue;
|
if (result.getType() == new_type) continue;
|
||||||
// Inserts a cast back to the original type if any user is not in the TF
|
// Inserts a cast back to the original type if any user is not in the TF
|
||||||
// dialect.
|
// dialect.
|
||||||
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
|
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
|
||||||
result.getType());
|
result.getType());
|
||||||
// Finally we inferred the shape and replace the type for this result.
|
// Finally we inferred the shape and replace the type for this result.
|
||||||
result.setType(new_type);
|
result.setType(new_type);
|
||||||
@ -755,23 +819,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
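
The hunks above — apparently from the TF MLIR shape-inference transform — fold a family of free functions into a `ShapeInference` class, so `graph_version_`, `tf_dialect_`, and the MLIR context become member state instead of parameters threaded through every call. A minimal, runnable sketch of that refactoring pattern, using toy stand-ins rather than the real MLIR types:

```cpp
#include <cstdint>
#include <iostream>

// Before the refactor: shared state is threaded through every helper, so
// adding one more piece of state widens every call site.
bool RefineValue(int value, int64_t version) { return value < version; }

// After the refactor: a small class owns the shared state, and helper
// signatures stay stable as more state (e.g. a dialect pointer) is added.
class Refiner {
 public:
  explicit Refiner(int64_t version) : version_(version) {}
  bool RefineValue(int value) const { return value < version_; }

 private:
  int64_t version_;
};

int main() {
  Refiner refiner(/*version=*/175);
  std::cout << refiner.RefineValue(42) << "\n";  // prints 1
}
```

The payoff is visible throughout the diff: call sites that previously passed `tf_dialect` and `graph_version` now pass nothing extra.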
 
-// Updates input types and refine shapes inside body of functions that are
-// attached to ControlFlow ops (If/While). These functions include Then/Else
-// branches of IfOp and Cond/Body functions of WhileOp. These functions share
-// following common properties:
-//   1) They are never reused, ie. having a single use in module.
-//   2) Their input types match those of their parent ops (excluding inputs like
-//      predicate).
-// Returns a boolean indicating whether any change has been applied.
-LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
-                                            llvm::ArrayRef<Type> input_types,
-                                            int64_t graph_version,
-                                            int64_t max_iteration) {
+LogicalResult ShapeInference::RefineShapeForControlFlowFunc(
+    FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) {
   ModuleOp module = func.getParentOfType<ModuleOp>();
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   int num_uses = std::distance(func_uses->begin(), func_uses->end());
   if (num_uses != 1) {
-    func.emitWarning(llvm::formatv(
+    func.emitWarning(formatv(
         "expected control flow function {0} to have exactly 1 use, found {1}.",
         func.getName(), num_uses));
     return failure();
@@ -785,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
     arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
   }
 
-  auto res =
-      InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration);
+  auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration);
   if (failed(res)) return res;
 
   auto new_return_types = InferShapeForFunctionReturnType(func);
@@ -798,20 +851,18 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
   return success();
 }
 
-LogicalResult PropagateShapeToFunctions(
+LogicalResult ShapeInference::PropagateShapeToFunctions(
     ModuleOp module, Operation::operand_type_range input_types,
-    llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
-    int64_t max_iteration) {
-  bool success = true;
+    ArrayRef<StringRef> func_names, int64_t max_iteration) {
+  bool all_succeeded = true;
   auto types = llvm::to_vector<4>(input_types);
   for (auto func_name : func_names) {
     FuncOp func = module.lookupSymbol<FuncOp>(func_name);
-    if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
-                                             max_iteration))) {
-      success = false;
-    }
+    all_succeeded =
+        succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) &&
+        all_succeeded;
  }
-  return mlir::success(success);
+  return success(all_succeeded);
 }
 
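Besides renaming the accumulator from `success` to `all_succeeded` (avoiding the shadowing of `mlir::success`), the rewrite evaluates the call *before* AND-ing the accumulator, so one failing function does not short-circuit processing of the rest. A small self-contained sketch of that accumulation pattern, with hypothetical names:

```cpp
#include <functional>
#include <iostream>
#include <vector>

// Process every item even if some fail, and report overall success at the
// end -- the same shape as the `all_succeeded` loop above.
bool ProcessAll(const std::vector<int>& items,
                const std::function<bool(int)>& process) {
  bool all_succeeded = true;
  for (int item : items) {
    // Call `process` first: writing `all_succeeded && process(item)` would
    // short-circuit and skip the remaining items after the first failure.
    all_succeeded = process(item) && all_succeeded;
  }
  return all_succeeded;
}

int main() {
  auto is_even = [](int v) { return v % 2 == 0; };
  std::cout << ProcessAll({2, 4, 6}, is_even) << "\n";  // 1
  std::cout << ProcessAll({2, 3, 6}, is_even) << "\n";  // 0
}
```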
 // If the callee has only one use, propagates any constant operand of call_op to
@@ -831,7 +882,7 @@ void PropagateConstantToCallee(CallOpInterface call_op,
   // the constant inside the function.
   for (auto arg : func.getArguments()) {
     auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
-    if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) {
+    if (isa_and_nonnull<TF::ConstOp>(operand)) {
       arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
     }
   }
@@ -850,33 +901,31 @@ void PropagateConstantFromCallee(CallOpInterface call_op,
   for (auto retval :
        llvm::enumerate(func.front().getTerminator()->getOperands())) {
     auto retval_op = retval.value().getDefiningOp();
-    if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) {
+    if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
       op->getResult(retval.index())
           .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
     }
   }
 }
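`PropagateConstantToCallee` clones a constant operand into the callee, which is only sound because these control-flow functions are guaranteed to have exactly one use. An illustrative toy version of that rule — plain C++ stand-ins under that stated single-use assumption, not the MLIR API:

```cpp
#include <iostream>
#include <optional>
#include <vector>

// Toy model: if the (single) caller passes a known constant for a parameter,
// the parameter can be treated as that constant inside the callee body.
struct Callee {
  std::vector<std::optional<int>> known_args;  // nullopt = not a constant
};

void PropagateConstantToCallee(Callee& callee,
                               const std::vector<std::optional<int>>& call_args,
                               int num_uses) {
  if (num_uses != 1) return;  // Only safe when the callee has a single use.
  for (size_t i = 0; i < call_args.size(); ++i)
    if (call_args[i]) callee.known_args[i] = call_args[i];
}

int main() {
  Callee callee{{std::nullopt, std::nullopt}};
  PropagateConstantToCallee(callee, {42, std::nullopt}, /*num_uses=*/1);
  std::cout << callee.known_args[0].value_or(-1) << "\n";  // 42
}
```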
 
-LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
-                                                  int64_t graph_version,
-                                                  int64_t max_iteration) {
+LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
+    Operation* op, int64_t max_iteration) {
   ModuleOp module = op->getParentOfType<ModuleOp>();
   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
     return PropagateShapeToFunctions(
-        module, llvm::drop_begin(if_op.getOperandTypes(), 1),
-        {if_op.then_branch(), if_op.else_branch()}, graph_version,
-        max_iteration);
+        module, drop_begin(if_op.getOperandTypes(), 1),
+        {if_op.then_branch(), if_op.else_branch()}, max_iteration);
   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
     return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
                                      {while_op.cond(), while_op.body()},
-                                     graph_version, max_iteration);
+                                     max_iteration);
   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
     CallInterfaceCallable callable = call_op.getCallableForCallee();
     if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) {
       PropagateConstantToCallee(call_op, sym, module);
       if (failed(PropagateShapeToFunctions(
               module, call_op.getArgOperands().getTypes(),
-              {sym.getRootReference()}, graph_version, max_iteration))) {
+              {sym.getRootReference()}, max_iteration))) {
         return failure();
       }
       PropagateConstantFromCallee(call_op, sym, module);
@@ -889,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
   return success();
 }
 
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
-                                      int64_t max_iteration) {
-  MLIRContext* ctx = region->getContext();
-  Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
+LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
+                                                      int64_t max_iteration) {
+  // An operation folder that is used to attempt folding before inference.
+  OperationFolder folder(context_);
 
-  // An operation folder that is used to attempt folding before inference.
-  OperationFolder folder(ctx);
   bool changed = true;
 
   // TODO(aminim): we could have a more efficient traversal by guiding the
@@ -908,14 +954,14 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
                << "Shape inference, iteration " << iteration << "\n");
     region->walk([&](Operation* op) {
       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
-        changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect);
+        changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
         // TODO(jpienaar): Debug why we can't just return here. We end up with
         // additional constant due to the propagation of constant into attached
         // function if we return already.
       }
 
-      if (op->getDialect() != tf_dialect) {
-        changed |= InferShapeForNonTFDialectOperation(op, tf_dialect);
+      if (op->getDialect() != tf_dialect_) {
+        changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_);
         return;
       }
 
@@ -924,13 +970,12 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
 
       // Best-effort shape inference in attached functions. Do not return
       // failure even if it doesn't get to fixed point.
-      if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version,
-                                                     max_iteration))) {
+      if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
         op->emitWarning() << "unable to refine shape of attached function "
                              "arguments and bodies";
       }
 
-      changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
+      changed |= InferShapeForSingleOperation(op);
     });
   }
 
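`InferShapeUntilFixPoint` keeps re-walking the region until no operation's type changes or the `max_iteration` budget is exhausted. A runnable sketch of that fixpoint loop on a toy refinement (the clamping rule here is invented purely for illustration):

```cpp
#include <iostream>
#include <vector>

// Iterate a local refinement until nothing changes or the iteration budget
// runs out -- the control structure of InferShapeUntilFixPoint above.
// The toy "refinement" clamps each element to its left neighbor.
bool InferUntilFixPoint(std::vector<int>& values, int max_iteration) {
  bool changed = true;
  for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
    changed = false;
    for (size_t i = 1; i < values.size(); ++i) {
      if (values[i] > values[i - 1]) {
        values[i] = values[i - 1];  // Refine toward monotone non-increasing.
        changed = true;
      }
    }
  }
  return !changed;  // true if a fixed point was reached within budget
}

int main() {
  std::vector<int> v = {5, 9, 3, 7};
  std::cout << InferUntilFixPoint(v, /*max_iteration=*/10) << "\n";  // 1
  for (int x : v) std::cout << x << " ";  // 5 5 3 3
  std::cout << "\n";
}
```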
@ -945,31 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
|||||||
LogicalResult InferShapeForFunction(FuncOp func,
|
LogicalResult InferShapeForFunction(FuncOp func,
|
||||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||||
int64_t graph_version) {
|
int64_t graph_version) {
|
||||||
mlir::FunctionType func_type = func.getType();
|
ShapeInference context(graph_version, func.getContext());
|
||||||
|
if (arg_shapes.empty()) {
|
||||||
|
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
|
||||||
|
return failure();
|
||||||
|
// TODO(b/156276510): Verify that it is always fine to refine a function's
|
||||||
|
// return type, as long as we do not change the argument shapes.
|
||||||
|
if (auto return_types = InferShapeForFunctionReturnType(func)) {
|
||||||
|
func.setType(FunctionType::get(func.getType().getInputs(),
|
||||||
|
return_types.getValue(),
|
||||||
|
func.getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
FunctionType func_type = func.getType();
|
||||||
bool needs_refinement = false;
|
bool needs_refinement = false;
|
||||||
llvm::SmallVector<mlir::Type, 4> new_arg_types;
|
SmallVector<Type, 4> new_arg_types;
|
||||||
new_arg_types.reserve(func_type.getNumInputs());
|
new_arg_types.reserve(func_type.getNumInputs());
|
||||||
|
|
||||||
// Update argument types in-place using the provided arg_shapes.
|
// Update argument types in-place using the provided arg_shapes.
|
||||||
for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
|
for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
|
||||||
ArrayRef<int64_t> shape = arg_shapes[i];
|
ArrayRef<int64_t> shape = arg_shapes[i];
|
||||||
mlir::Type element_type;
|
Type element_type;
|
||||||
if (auto input_ty =
|
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
|
||||||
func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) {
|
|
||||||
if (!input_ty || input_ty.getShape().size() != shape.size()) {
|
if (!input_ty || input_ty.getShape().size() != shape.size()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
element_type = input_ty.getElementType();
|
element_type = input_ty.getElementType();
|
||||||
} else {
|
} else {
|
||||||
auto unranked_input_ty =
|
auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
|
||||||
func_type.getInput(i).dyn_cast<mlir::TensorType>();
|
|
||||||
if (!unranked_input_ty) {
|
if (!unranked_input_ty) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
element_type = unranked_input_ty.getElementType();
|
element_type = unranked_input_ty.getElementType();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
|
auto new_arg_type = RankedTensorType::get(shape, element_type);
|
||||||
if (new_arg_type != func_type.getInput(i)) {
|
if (new_arg_type != func_type.getInput(i)) {
|
||||||
// If the new type is more detailed, trigger shape inference.
|
// If the new type is more detailed, trigger shape inference.
|
||||||
func.getArgument(i).setType(new_arg_type);
|
func.getArgument(i).setType(new_arg_type);
|
||||||
@ -982,14 +1039,13 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::LogicalResult result =
|
LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody());
|
||||||
mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
|
|
||||||
if (failed(result)) {
|
if (failed(result)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto return_types = InferShapeForFunctionReturnType(func);
|
auto return_types = InferShapeForFunctionReturnType(func);
|
||||||
func.setType(mlir::FunctionType::get(new_arg_types,
|
func.setType(FunctionType::get(new_arg_types,
|
||||||
return_types.hasValue()
|
return_types.hasValue()
|
||||||
? return_types.getValue()
|
? return_types.getValue()
|
||||||
: func.getType().getResults(),
|
: func.getType().getResults(),
|
||||||
@ -998,15 +1054,5 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult InferShapeForFunctionType(FuncOp func) {
|
|
||||||
if (auto return_types = InferShapeForFunctionReturnType(func)) {
|
|
||||||
func.setType(mlir::FunctionType::get(func.getType().getInputs(),
|
|
||||||
return_types.getValue(),
|
|
||||||
func.getContext()));
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
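The new `InferShapeForFunction` treats an empty `arg_shapes` list as "leave the argument types alone and only infer over the body", which is what lets the standalone `InferShapeForFunctionType` entry point be deleted. A toy sketch of that contract; the types are stand-ins, not the MLIR ones:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;  // -1 models a dynamic dimension

// Empty `arg_shapes` = body-only inference, argument shapes untouched.
// Otherwise the arguments are refined first (ranks must match).
bool InferShapeForFunction(std::vector<Shape>& current_arg_shapes,
                           const std::vector<Shape>& arg_shapes) {
  if (arg_shapes.empty()) return true;  // body-only path
  if (arg_shapes.size() != current_arg_shapes.size()) return false;
  current_arg_shapes = arg_shapes;      // refine arguments, then infer
  return true;
}

int main() {
  std::vector<Shape> args = {{-1, 4}};
  InferShapeForFunction(args, {});        // leaves {-1, 4} as-is
  InferShapeForFunction(args, {{8, 4}});  // refines to {8, 4}
  std::cout << args[0][0] << "x" << args[0][1] << "\n";  // 8x4
}
```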
@@ -27,30 +27,13 @@ namespace mlir {
 
 namespace TF {
 
-// Performs shape inference on the provided op and return true if the type of
-// at least one result has been changed.
-// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
-// `graph_version` indicates the current GraphDef compatibility versions
-// (the versions field in graph.proto).
-bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
-                                  int64_t graph_version);
-
-// Infers shape on the provided region, including nested ones, iterate until fix
-// point with a limit of max_iteration. Returns success if fix point is reached
-// before max_iteration.
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
-                                      int64_t max_iteration = 10);
-
 // Given a list of refined shapes matching the function arguments of func, runs
 // shape inference over the function to propagate this updated information.
+// If arg_shapes are empty, then argument shapes will be left unchanged.
 LogicalResult InferShapeForFunction(FuncOp func,
                                     ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                     int64_t graph_version);
 
-// Refines the return type of the given function by folding tf.Cast that
-// precedes the return instruction.
-LogicalResult InferShapeForFunctionType(FuncOp func);
-
 }  // namespace TF
 
 }  // namespace mlir
@@ -58,10 +58,8 @@ struct ShapeInference
     }
     int64_t producer = producer_or.ValueOrDie();
     for (auto func : module.getOps<FuncOp>()) {
-      InferShapeUntilFixPoint(&func.getBody(), producer);
-      // TODO(yuanzx): Verify that it is always fine to refine a function's
-      // return type, as long as we do not change the argument shapes.
-      InferShapeForFunctionType(func);
+      if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
+        return signalPassFailure();
     }
   }
 };
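With the helpers gone from the public header, the pass drives everything through `InferShapeForFunction` and now fails loudly via `signalPassFailure` instead of silently continuing. A minimal sketch of that failure-propagation shape, with invented names standing in for the MLIR pass machinery:

```cpp
#include <iostream>
#include <vector>

struct PassResult {
  bool failed = false;
  void signalPassFailure() { failed = true; }
};

bool InferFunction(int id) { return id != 3; }  // pretend function 3 fails

void RunOnModule(const std::vector<int>& function_ids, PassResult& result) {
  for (int id : function_ids) {
    if (!InferFunction(id)) {
      result.signalPassFailure();
      return;  // Stop at the first function that fails to infer.
    }
  }
}

int main() {
  PassResult result;
  RunOnModule({1, 2, 3, 4}, result);
  std::cout << (result.failed ? "pass failed" : "pass ok") << "\n";
}
```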
@@ -14,11 +14,23 @@ limitations under the License.
 ==============================================================================*/
 
 #include <memory>
+#include <type_traits>
 
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 
 namespace mlir {
 namespace TFTPU {
@@ -30,30 +42,182 @@ namespace {
 
 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
 
-struct TPUExtractHeadTailOutsideCompilation
-    : public PassWrapper<TPUExtractHeadTailOutsideCompilation, FunctionPass> {
-  void runOnFunction() override;
-};
+bool HasOutsideCompilationAttribute(Operation* op) {
+  return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
+}
 
-void TPUExtractHeadTailOutsideCompilation::runOnFunction() {
-  getFunction().walk([&](tf_device::LaunchOp launch) {
-    Block& launch_block = launch.GetBody();
-    for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) {
-      // TODO(b/155115766): Handle outputs that should be inputs to TPU
-      // LaunchOp.
-      if (auto attr =
-              op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
-        op.moveBefore(launch);
-      } else {
+// Returns whether all operands of `op` are from values inside the
+// `input_value_set`.
+bool OpContainsOperandsFromSet(Operation* op,
+                               const llvm::SetVector<Value>& input_value_set) {
+  for (auto operand : op->getOperands())
+    if (input_value_set.count(operand) == 0) return false;
+
+  return true;
+}
+
+void RecordOutsideCompiledOpsAndUsages(
+    Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
+    llvm::SetVector<Value>* outside_compiled_op_usages) {
+  if (HasOutsideCompilationAttribute(op) &&
+      OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
+    outside_compiled_ops->insert(op);
+    outside_compiled_op_usages->insert(op->getResults().begin(),
+                                       op->getResults().end());
+  }
+}
+
+// Traverses the MLIR graph and returns a set of ops that
+// are connected to inputs of TPU computation and outside compiled.
+void ExtractOutsideCompiledOpsConnectedToHead(
+    Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
+    llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
+  llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
+  for (auto& usage : input_value.getUses()) {
+    auto head_operation = usage.getOwner();
+    RecordOutsideCompiledOpsAndUsages(head_operation,
+                                      &parent_outside_compiled_ops_at_head,
+                                      values_used_in_host_cluster);
+  }
+
+  // Traverse the graph and find all outside compiled ops connected from
+  // the `input_value`.
+  while (!parent_outside_compiled_ops_at_head.empty()) {
+    llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
+    for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
+      auto op_results = head_outside_compiled_op->getOpResults();
+      for (auto op_result : op_results) {
+        for (auto& use : op_result.getUses()) {
+          auto connected_op = use.getOwner();
+          RecordOutsideCompiledOpsAndUsages(connected_op,
+                                            &connected_outside_compiled_ops,
+                                            values_used_in_host_cluster);
+        }
+      }
+    }
+
+    outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
+                                 parent_outside_compiled_ops_at_head.end());
+    std::swap(parent_outside_compiled_ops_at_head,
+              connected_outside_compiled_ops);
+  }
+}
+
+// TODO(hongjunchoi): Also handle ops without inputs that are outside
+// compiled.
+//
+// Returns set of ops that are outside compiled and are directly connected
+// to inputs to the TPU computation.
+llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
+    tf_device::ClusterOp tpu_cluster) {
+  llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
+  llvm::SetVector<Value> values_used_in_cluster;
+  auto& cluster_region = tpu_cluster.body();
+  getUsedValuesDefinedAbove(cluster_region, cluster_region,
+                            values_used_in_cluster);
+
+  auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
+  for (auto input_value : input_value_list)
+    ExtractOutsideCompiledOpsConnectedToHead(
+        input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
+  return outside_compiled_at_head_ops;
+}
+
+// Returns output values of extracted outside compiled cluster at head that
+// are used by the TPU computation.
+llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
+    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+  llvm::SmallVector<Value, 8> outputs;
+  outputs.reserve(head_outside_compiled_ops.size());
+
+  for (auto op : head_outside_compiled_ops) {
+    for (Operation* user : op->getUsers()) {
+      if (!head_outside_compiled_ops.count(user)) {
+        outputs.append(op->result_begin(), op->result_end());
         break;
       }
     }
+  }
+
+  return outputs;
+}
+
+// Creates new tf_device.launch op with outside compiled ops extracted
+// from the head of TPU computation.
+llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
+    OpBuilder* builder, tf_device::ClusterOp cluster,
+    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+  if (head_outside_compiled_ops.empty())
+    return llvm::Optional<tf_device::LaunchOp>();
+
+  // Create tf_device.launch op to separate all extracted outside compiled ops
+  // before the tf_device.cluster.
+  auto output_values =
+      GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
+
+  llvm::SmallVector<Type, 8> output_return_types;
+  output_return_types.reserve(output_values.size());
+  for (auto output : output_values)
+    output_return_types.emplace_back(output.getType());
+
+  builder->setInsertionPoint(cluster);
+  auto host_launch_op = builder->create<tf_device::LaunchOp>(
+      cluster.getLoc(), builder->getStringAttr(""), output_return_types);
+
+  // Replace all usages of outside compiled ops that are used in TPU
+  // computation with the results of the above created launch op.
+  for (auto output_and_index : llvm::enumerate(output_values)) {
+    auto output_index = output_and_index.index();
+    auto output = output_and_index.value();
+    for (auto& use : output.getUses()) {
+      if (!head_outside_compiled_ops.count(use.getOwner()))
+        use.set(host_launch_op.getResult(output_index));
+    }
+  }
+
+  // Create terminator op for the newly created launch op.
+  host_launch_op.body().push_back(new Block());
+  builder->setInsertionPointToEnd(&host_launch_op.GetBody());
+  auto terminator = builder->create<tf_device::ReturnOp>(
+      host_launch_op.getLoc(), output_values);
+
+  // Move all outside compile ops from cluster op to launch op.
+  for (auto outside_compiled_op : head_outside_compiled_ops)
+    outside_compiled_op->moveBefore(terminator);
+
+  return host_launch_op;
+}
+
+struct TPUExtractHeadTailOutsideCompilation
+    : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
+                         OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
+  // Get runtime devices information from the closest parent module.
+  auto module = getOperation();
+  mlir::TF::RuntimeDevices devices;
+  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+    return signalPassFailure();
+
+  OpBuilder builder(&getContext());
+  module.walk([&](tf_device::ClusterOp cluster) {
+    auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
+    IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
                                       head_outside_compiled_ops);
+
+    // TODO(b/156030523): Update device attribute of newly created host launch
+    // op as well as enclosing Replicate op (if TPU computation is replicated)
+    // with host device names.
+
+    // TODO(b/155115766): Implement tail outside compiled op extraction.
   });
 }
 
 }  // anonymous namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTPUExtractHeadTailOutsideCompilationPass() {
   return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
 }
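The head-extraction logic above grows a frontier: an op is hoisted to the host only if it carries the outside-compilation attribute *and* all of its operands are already host-visible, and each hoisted op's results extend the frontier for the next round. A self-contained sketch of that traversal on a toy op list (the op representation is invented for illustration):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string name;
  bool outside_compiled;      // carries the marker attribute
  std::vector<int> operands;  // value ids consumed
  std::vector<int> results;   // value ids produced
};

// Repeatedly absorb marked ops whose operands are all host-available,
// level by level, until the frontier stops growing.
std::set<size_t> ExtractHead(const std::vector<Op>& ops,
                             std::set<int> available) {
  std::set<size_t> extracted;
  bool changed = true;
  while (changed) {
    changed = false;
    for (size_t i = 0; i < ops.size(); ++i) {
      if (extracted.count(i) || !ops[i].outside_compiled) continue;
      bool ready = true;
      for (int v : ops[i].operands) ready &= available.count(v) > 0;
      if (!ready) continue;
      extracted.insert(i);  // Hoist op i; its results become host values.
      for (int v : ops[i].results) available.insert(v);
      changed = true;
    }
  }
  return extracted;
}

int main() {
  // %2 = a(%0); %3 = b(%2); %4 = c(%1) -- a and b marked, c not.
  std::vector<Op> ops = {{"a", true, {0}, {2}},
                         {"b", true, {2}, {3}},
                         {"c", false, {1}, {4}}};
  for (size_t i : ExtractHead(ops, /*available=*/{0, 1}))
    std::cout << ops[i].name << " ";  // prints: a b
  std::cout << "\n";
}
```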
@@ -34,7 +34,7 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
 constexpr char kDeviceAttr[] = "device";
 
 // Mapping for `_xla_outside_compilation` attribute to ops of a cluster.
-using ClusterMap =
+using OutsideClusterMap =
     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
 
 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
@@ -51,7 +51,8 @@ struct TPUExtractOutsideCompilation
 // Collects and clusters ops in `block` with the same `_xla_outside_compilation`
 // attribute into `clusters` This returns an error if a
 // `_xla_outside_compilation` attribute of an op is empty.
-LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
+LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
+                                               OutsideClusterMap* clusters) {
   for (Operation& op : *block) {
     if (auto attr = op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
       if (attr.getValue().empty())
@@ -67,7 +68,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
 }
 
 // Moves `cluster_ops` to associated `launch_op` body.
-void MoveClusterOpsToLaunchOp(
+void MoveOutsideClusterOpsToLaunchOp(
     tf_device::LaunchOp launch_op,
     const llvm::SmallVector<Operation*, 8>& cluster_ops) {
   MLIRContext* context = launch_op.getContext();
@@ -84,8 +85,8 @@ void MoveClusterOpsToLaunchOp(
 }
 
 // Creates a `tf_device::LaunchOp` to wrap cluster ops.
-tf_device::LaunchOp CreateLaunchOpForCluster(OpBuilder* builder,
-                                             Operation* last_cluster_op) {
+tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
+    OpBuilder* builder, Operation* last_cluster_op) {
   // TODO(b/154363171): Set the CPU device.
   // An empty string placeholder is used for the device as that will be later
   // populated with the device of the associated TPUReplicateMetadata op.
@@ -117,14 +118,14 @@ void PropagateParallelExecuteReturnToReplicate(
 
 // Creates a `parallel_execute` op in place of launch with 'clusters` and
 // 'launch` as regions.
-void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
-                                       const ClusterMap& clusters) {
-  OpBuilder builder(launch);
+void CreateParallelExecuteFromOutsideClusters(
+    tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
+  OpBuilder builder(tpu_cluster);
   // Create parallel_execute regions. The original TPU cluster computation
   // is the extra region.
   int num_regions = 1 + clusters.size();
   auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
-      launch.getLoc(), num_regions, launch.results().getTypes());
+      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
 
   // Move outside compilation clusters to parallel_execute regions.
   for (const auto& cluster : llvm::enumerate(clusters)) {
@@ -134,21 +135,23 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
         parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
     builder.setInsertionPointToEnd(&outside_block);
     tf_device::LaunchOp launch_op =
-        CreateLaunchOpForCluster(&builder, cluster_ops.back());
-    MoveClusterOpsToLaunchOp(launch_op, cluster_ops);
+        CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back());
+    MoveOutsideClusterOpsToLaunchOp(launch_op, cluster_ops);
     builder.setInsertionPointToEnd(&outside_block);
     // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
     // regions either through communication with TPU parallel_execute regions
     // or modifying parallel_execute returns.
-    builder.create<tf_device::ReturnOp>(launch.getLoc(), ArrayRef<Value>{});
+    builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
+                                        ArrayRef<Value>{});
   }
 
   // Move the launch body to last parallel_execute block.
   Block& inside_block =
       parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1);
   builder.setInsertionPointToEnd(&inside_block);
-  builder.create<tf_device::ReturnOp>(launch.getLoc(), launch.getResults());
-  launch.getOperation()->moveBefore(inside_block.getTerminator());
+  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
+                                      tpu_cluster.getResults());
+  tpu_cluster.getOperation()->moveBefore(inside_block.getTerminator());
 
   PropagateParallelExecuteReturnToReplicate(parallel_execute_op);
   // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
@@ -157,14 +160,16 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
 }
 
 void TPUExtractOutsideCompilation::runOnFunction() {
-  auto extract_result = getFunction().walk([&](tf_device::LaunchOp launch) {
-    ClusterMap clusters;
-    if (failed(CollectAndGroupClusterOps(&launch.GetBody(), &clusters)))
+  auto extract_result =
+      getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
+        OutsideClusterMap clusters;
+        if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
+                                                    &clusters)))
          return WalkResult::interrupt();
 
        if (clusters.empty()) return WalkResult::advance();
 
-    CreateParallelExecuteFromClusters(launch, clusters);
+        CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters);
 
        return WalkResult::advance();
      });
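`CollectAndGroupOutsideClusterOps` buckets ops by the string value of their `_xla_outside_compilation` attribute, treating an empty attribute value as an error. A small stand-alone sketch of that grouping step, with toy types in place of the MLIR ones:

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::optional<std::string> outside_compilation;  // the marker attribute
};

bool CollectAndGroup(const std::vector<Op>& block,
                     std::map<std::string, std::vector<const Op*>>& clusters) {
  for (const Op& op : block) {
    if (!op.outside_compilation) continue;  // unmarked: stays on the TPU
    if (op.outside_compilation->empty()) return false;  // malformed marker
    clusters[*op.outside_compilation].push_back(&op);
  }
  return true;
}

int main() {
  std::vector<Op> block = {{"a", "0"}, {"b", std::nullopt}, {"c", "0"}};
  std::map<std::string, std::vector<const Op*>> clusters;
  if (CollectAndGroup(block, clusters))
    std::cout << "cluster 0 has " << clusters["0"].size() << " ops\n";  // 2
}
```

Each bucket then becomes one host-side `tf_device.launch` region of the `parallel_execute` built above.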
@@ -92,7 +92,7 @@ constexpr char kBadArrayAttrLengthMsg[] =
 //
 // Would become following ops (unimportant attributes, types are omitted):
 //   %1 = "tf.Shape"(%0)
-//   %2:2 = "tf.MLIRCompileToTPU"(%1) {module = "<Serialized @tpu_func>"}
+//   %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
 //   "tf.TPUCompileSucceededAssert"(%2#0)
 //   %3 = "tf.TPUExecute"(%0, %2#1)
 //   %4 = "tf.SomeOp"(%3)
@@ -448,19 +448,20 @@ Operation* BuildCompileOp(
 // core, and all replica devices per core are grouped together.
 void AssignDevicesToReplicate(
     tf_device::ReplicateOp replicate,
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     OpBuilder* builder) {
   if (!replicate) return;
 
-  const int num_replicas = execution_devices.size();
-  const int num_cores_per_replica = execution_devices.front().size();
+  const int num_replicas = tpu_devices.size();
+  const int num_cores_per_replica = tpu_devices.front().size();
 
   llvm::SmallVector<NamedAttribute, 8> device_attrs;
   for (int core = 0; core < num_cores_per_replica; ++core) {
     llvm::SmallVector<StringRef, 8> devices_by_core;
     devices_by_core.reserve(num_replicas);
     for (int replica = 0; replica < num_replicas; ++replica)
-      devices_by_core.push_back(execution_devices[replica][core]);
+      devices_by_core.push_back(tpu_devices[replica][core].device);
 
     device_attrs.push_back(
         builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
@@ -492,11 +493,12 @@ LogicalResult BuildExecuteOp(
 // Creates a tf_device.parallel_execute op that wraps TPUExecute op to
 // represent execution of TPU program in multiple logical cores.
 LogicalResult BuildParallelExecuteOp(
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     llvm::ArrayRef<xla::OpSharding> output_sharding_config,
     Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
     OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
-  const int num_cores_per_replica = execution_devices.front().size();
+  const int num_cores_per_replica = tpu_devices.front().size();
   // parallel_execute op returns concatenated list of return values of
   // all its regions.
   //
@@ -528,7 +530,7 @@ LogicalResult BuildParallelExecuteOp(
       num_cores_per_replica, cluster_func, builder, &input_list);
   if (failed(result)) return failure();
 
-  const bool replicated = execution_devices.size() != 1;
+  const bool replicated = tpu_devices.size() != 1;
   // For each logical core, create a region with TPUExecute op.
   assert(input_list.size() == num_cores_per_replica);
   for (int core = 0; core < num_cores_per_replica; ++core) {
@@ -553,7 +555,7 @@ LogicalResult BuildParallelExecuteOp(
     // op.
     std::string device = replicated
                              ? tensorflow::GetDeviceAliasForLogicalCore(core)
-                             : execution_devices.front()[core];
+                             : tpu_devices.front()[core].device;
 
     auto region_launch_op =
         WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);
@@ -566,13 +568,14 @@ LogicalResult BuildParallelExecuteOp(
 }
 
 tf_device::LaunchOp AssignDevicesToReplicatedExecute(
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     Operation* execute_op, OpBuilder* builder) {
-  const bool replicated = execution_devices.size() != 1;
+  const bool replicated = tpu_devices.size() != 1;
   // If computation is replicated, use aliased device. Otherwise there is only
   // one execution device and the device is assigned to the execute op.
   std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
-                                  : execution_devices.front().front();
+                                  : tpu_devices.front().front().device;
 
   return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
 }
@@ -687,6 +690,16 @@ LogicalResult Rewrite(
   // Create compile op.
   auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
   builder->setInsertionPoint(cluster_func);
 
+  // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
+  // parallel_execute region if it exists.
+  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
+    // Currently, outside compilation and model parallelism are not supported
+    // together.
+    assert(num_cores_per_replica == 1);
+    builder->setInsertionPoint(cluster_func.getParentOp());
+  }
+
   Operation* compile_op = BuildCompileOp(
       cluster_func, num_replicas, num_cores_per_replica,
       tpu_device_assignment.compilation_device,
@@ -704,7 +717,7 @@ LogicalResult Rewrite(
   BuildTPUCompileSucceededAssertOp(
       compile_op, tpu_device_assignment.compilation_device, builder);
 
-  AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices,
+  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                            builder);
 
   llvm::SmallVector<xla::OpSharding, 4> output_shardings;
@@ -712,12 +725,13 @@ LogicalResult Rewrite(
       num_cores_per_replica, cluster_func, &output_shardings);
   if (failed(result)) return failure();
 
+  builder->setInsertionPoint(cluster_func);
   if (num_cores_per_replica > 1) {
     // For model parallelism, tf_device.parallel_execute is used to express
     // concurrent device execution across multiple logical devices.
 
     tf_device::ParallelExecuteOp execute_op;
-    result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
+    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                     output_shardings, compile_op, cluster_func,
                                     builder, &execute_op);
     if (failed(result)) return failure();
@@ -740,7 +754,7 @@ LogicalResult Rewrite(
   if (failed(result)) return failure();
 
   tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
-      tpu_device_assignment.execution_devices, execute_op, builder);
+      tpu_device_assignment.tpu_devices, execute_op, builder);
   cluster_func.replaceAllUsesWith(launch_op);
 }
 
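These hunks swap the per-(replica, core) entry from a bare device `std::string` to a `tensorflow::TPUDeviceAndHost` record, so call sites now read `.device` (and, judging by the name, a matching host device travels alongside for host-side ops). A toy sketch of the shape of that change; the struct layout and device strings below are assumptions for illustration only:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::TPUDeviceAndHost: the field `device`
// appears in the diff, `host` is inferred from the type name.
struct TPUDeviceAndHost {
  std::string device;  // e.g. ".../device:TPU:0"
  std::string host;    // CPU device on the same task, for host-side work
};

int main() {
  // Indexed as tpu_devices[replica][core], matching the diff.
  std::vector<std::vector<TPUDeviceAndHost>> tpu_devices = {
      {{"/task:0/device:TPU:0", "/task:0/device:CPU:0"}}};
  // Call sites change from tpu_devices[r][c] to tpu_devices[r][c].device.
  std::cout << tpu_devices.front().front().device << "\n";
}
```

Bundling the pair keeps the device-to-host mapping from being recomputed at every site that needs it.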
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@ -57,6 +58,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Types.h" // from @llvm-project
|
#include "mlir/IR/Types.h" // from @llvm-project
|
||||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||||
@ -65,6 +67,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
@ -109,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
using mlir::NamedAttrList;
|
||||||
using mlir::TensorType;
|
using mlir::TensorType;
|
||||||
using mlir::TF::VarHandleOp;
|
using mlir::TF::VarHandleOp;
|
||||||
using mlir::tf_saved_model::GlobalTensorOp;
|
using mlir::tf_saved_model::GlobalTensorOp;
|
||||||
@ -306,9 +310,9 @@ class ImporterBase {
|
|||||||
// AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
|
// AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
|
||||||
// a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
|
// a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
|
||||||
// {base_name.k2 : rfc}}.
|
// {base_name.k2 : rfc}}.
|
||||||
Status ConvertFunctionCallAttribute(
|
Status ConvertFunctionCallAttribute(const std::string& base_name,
|
||||||
const std::string& base_name, const AttrValue& value,
|
const AttrValue& value,
|
||||||
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
|
NamedAttrList* attributes);
|
||||||
|
|
||||||
// Helper to create either a tf_executor operation or a TF operation wrapped
|
// Helper to create either a tf_executor operation or a TF operation wrapped
|
||||||
// in an island. When convert_to_legacy_call is true, converts the operation
|
// in an island. When convert_to_legacy_call is true, converts the operation
|
||||||
@ -1089,9 +1093,9 @@ StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
|
|||||||
return subtypes;
|
return subtypes;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ImporterBase::ConvertFunctionCallAttribute(
|
Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
|
||||||
const std::string& base_name, const AttrValue& value,
|
const AttrValue& value,
|
||||||
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) {
|
NamedAttrList* attributes) {
|
||||||
TF_ASSIGN_OR_RETURN(auto func_attr,
|
TF_ASSIGN_OR_RETURN(auto func_attr,
|
||||||
ConvertFunctionCallName(value.func().name()));
|
ConvertFunctionCallName(value.func().name()));
|
||||||
attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
|
attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
|
||||||
@@ -2428,8 +2432,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
   // Main entry point: converts all functions in the given meta graph to an MLIR
   // Module.
   static StatusOr<mlir::OwningModuleRef> Convert(
-      SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
-      absl::Span<std::string> exported_names, bool add_default_attributes);
+      SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
+      mlir::MLIRContext* context, bool add_default_attributes);

  private:
   explicit SavedModelObjectGraphImporter(

@@ -3129,8 +3133,8 @@ Status CreateSavedModelIR(
 }

 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
-    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
-    absl::Span<std::string> exported_names, bool add_default_attributes) {
+    SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
+    mlir::MLIRContext* context, bool add_default_attributes) {
   GraphDebugInfo dummy_debug_info;
   const GraphDebugInfo& debug_info =
       saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;

@@ -3207,17 +3211,20 @@ class SavedModelSignatureDefImporter {
  public:
   // Main entry point: converts all functions (specified by SignatureDefs) in
   // the given meta graph to an MLIR Module.
-  static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
+  static StatusOr<mlir::OwningModuleRef> Convert(
+      const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
       mlir::MLIRContext* context) {
-    SavedModelSignatureDefImporter importer(bundle, context);
+    SavedModelSignatureDefImporter importer(bundle, exported_names, context);

     return importer.ConvertSignatures();
   }

  private:
   SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
+                                 absl::Span<std::string> exported_names,
                                  mlir::MLIRContext* context)
       : bundle_(bundle),
+        exported_names_(exported_names),
         module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}

   // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function

@@ -3250,6 +3257,7 @@ class SavedModelSignatureDefImporter {
       const std::vector<std::pair<std::string, TensorInfo>>& inputs);

   const SavedModelBundle& bundle_;
+  absl::Span<std::string> exported_names_;
   mlir::OwningModuleRef module_;
 };

@@ -3265,6 +3273,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
   GraphDebugInfo debug_info;
   if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;

+  llvm::StringSet<> exported_name_set;
+  exported_name_set.insert(exported_names_.begin(), exported_names_.end());
+
   for (const auto& key_and_signature_def : signatures) {
     const std::string& sig_def_key = key_and_signature_def.first;
     const SignatureDef& signature_def = key_and_signature_def.second;

@@ -3274,6 +3285,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
     if (sig_def_key == "__saved_model_init_op") {
       continue;
     }
+    if (!exported_name_set.empty() &&
+        exported_name_set.count(sig_def_key) == 0) {
+      continue;
+    }

     TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
                                         debug_info, flib_def));
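The filtering rule the new lines implement can be stated compactly: an empty `exported_names` means "convert every signature". A sketch under that reading; `ShouldConvertSignature` is an invented helper name:

    #include "llvm/ADT/StringRef.h"
    #include "llvm/ADT/StringSet.h"

    // Invented helper; mirrors the new check in ConvertSignatures().
    bool ShouldConvertSignature(const llvm::StringSet<>& exported_name_set,
                                llvm::StringRef sig_def_key) {
      // An empty set disables filtering; otherwise only listed keys are kept.
      return exported_name_set.empty() ||
             exported_name_set.count(sig_def_key) > 0;
    }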
@@ -3556,12 +3571,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
     SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
     absl::Span<std::string> exported_names, bool add_default_attributes) {
   return SavedModelObjectGraphImporter::Convert(
-      saved_model, context, exported_names, add_default_attributes);
+      saved_model, exported_names, context, add_default_attributes);
 }

 StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
-    const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
-  return SavedModelSignatureDefImporter::Convert(saved_model, context);
+    const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
+    mlir::MLIRContext* context) {
+  return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
+                                                 context);
 }

 std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {

@@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
 // expressed with tf_executor dialect.
 stream_executor::port::StatusOr<mlir::OwningModuleRef>
 ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
+                          absl::Span<std::string> exported_names,
                           mlir::MLIRContext* context);

 // Serialize a MLIR module to a string.

@@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(

 mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
-    const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
   tensorflow::SavedModelBundle bundle;
   tensorflow::SessionOptions session_options;
   // Force saved model states to be restored to CPU.

@@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     return nullptr;
   }

-  auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
+  auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
   if (!module_or.status().ok()) {
     LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
     return nullptr;

@@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
 // given MLIR `context`.
 mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
     absl::string_view saved_model_dir,
-    const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context);

 }  // namespace tensorflow
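A hedged example of calling the updated V1 entry point, matching the declarations above; `bundle` is assumed to be an already-loaded `SavedModelBundle` and `context` an existing `mlir::MLIRContext`:

    #include <string>
    #include <vector>
    #include "absl/types/span.h"

    // Sketch only: restrict the import to one signature. An empty span would
    // import all SignatureDefs, per the filtering logic added above.
    std::vector<std::string> exported_names = {"serving_default"};
    auto module_or = tensorflow::ConvertSavedModelV1ToMlir(
        bundle, absl::MakeSpan(exported_names), &context);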
@@ -293,6 +293,12 @@ Status ConvertMLIRToXlaComputation(
   tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

+  // Run shape inference pass to propagate shapes through tensor_cast operations
+  // from static to dynamic shapes. This could be generated if the shape
+  // inference was originally missing in a TF op but the corresponding HLO op
+  // had static shape after lowering.
+  tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
+
   // Run LegalizeTFPass again because the previous legalization passes can
   // expose more graph pruning and canonicalization opportunities that are
   // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
@@ -31,12 +31,14 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/tstring.h"
 #include "tensorflow/stream_executor/lib/statusor.h"

@@ -131,13 +133,21 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
   case DTYPE:                                            \
     return ConvertFlatTensor<CTYPE>(input_tensor, type);

-  // TODO(fengliuai): customize the conversions for more types.
+  // TODO(fengliuai): customize the conversions for quantized and string types.
   switch (input_dtype) {
     CONVERT_FLAT(DT_BOOL, bool)
     CONVERT_FLAT(DT_FLOAT, float)
     CONVERT_FLAT(DT_DOUBLE, double)
+    CONVERT_FLAT(DT_INT8, int8)
+    CONVERT_FLAT(DT_INT16, int16)
     CONVERT_FLAT(DT_INT32, int32)
     CONVERT_FLAT(DT_INT64, int64)
+    CONVERT_FLAT(DT_UINT8, uint8)
+    CONVERT_FLAT(DT_UINT16, uint16)
+    CONVERT_FLAT(DT_UINT32, uint32)
+    CONVERT_FLAT(DT_UINT64, uint64)
+    CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
+    CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)

     // BFLOAT16 is a special case that it needs to be cast to double type to
     // match its storage type.
@ -207,12 +217,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
|
|||||||
|
|
||||||
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
||||||
// proto.
|
// proto.
|
||||||
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
|
void ConvertStringElementsAttr(
|
||||||
TensorProto* output_tensor) {
|
const DenseStringElementsAttr attr,
|
||||||
for (const auto& val : attr.getRawStringData()) {
|
protobuf::RepeatedPtrField<std::string>* output) {
|
||||||
output_tensor->add_string_val(val.data(), val.size());
|
for (const auto& val : attr.getRawStringData())
|
||||||
|
output->Add({val.data(), val.size()});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
|
||||||
|
protobuf::RepeatedField<T>* output) {
|
||||||
|
for (const auto& val : attr.getValues<std::complex<T>>()) {
|
||||||
|
output->Add(val.real());
|
||||||
|
output->Add(val.imag());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||||
@ -226,139 +244,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
|||||||
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
// Converts an MLIR elements attribute and adds it to specified repeated field.
|
||||||
// with the double_val field updated.
|
template <typename T>
|
||||||
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
|
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
|
||||||
TensorProto* output_tensor) {
|
protobuf::RepeatedField<T>* output) {
|
||||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
if (attr.isSplat()) {
|
||||||
if (elts.isSplat()) {
|
output->Add(attr.getSplatValue<T>());
|
||||||
output_tensor->add_double_val(elts.getSplatValue<double>());
|
|
||||||
} else {
|
} else {
|
||||||
for (auto value : elts.getValues<double>())
|
for (auto value : attr.getValues<T>()) output->Add(value);
|
||||||
output_tensor->add_double_val(value);
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
// Converts an MLIR elements attribute containing half values and adds it to
|
||||||
// with the float_val field updated.
|
// specified repeated field.
|
||||||
Status ConvertFloatElementsAttr(const ElementsAttr attr,
|
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
|
||||||
TensorProto* output_tensor) {
|
protobuf::RepeatedField<int>* output_tensor) {
|
||||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
if (attr.isSplat()) {
|
||||||
if (elts.isSplat()) {
|
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
||||||
output_tensor->add_float_val(elts.getSplatValue<float>());
|
|
||||||
} else {
|
} else {
|
||||||
for (auto value : elts.getValues<float>())
|
for (const llvm::APFloat value : attr.getFloatValues())
|
||||||
output_tensor->add_float_val(value);
|
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
// Converts an MLIR elements attribute containing int values and adds it to
|
||||||
// with the half_val field updated.
|
// specified repeated field.
|
||||||
Status ConvertHalfElementsAttr(const ElementsAttr attr,
|
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
||||||
TensorProto* output_tensor) {
|
protobuf::RepeatedField<int>* output) {
|
||||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
if (attr.isSplat()) {
|
||||||
if (elts.isSplat()) {
|
output->Add((*attr.begin()).getSExtValue());
|
||||||
output_tensor->add_half_val(
|
|
||||||
(*elts.begin()).bitcastToAPInt().getSExtValue());
|
|
||||||
} else {
|
} else {
|
||||||
for (const auto& value : elts.getFloatValues())
|
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
|
||||||
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
||||||
// with the int_val field updated.
|
protobuf::RepeatedField<int>* output) {
|
||||||
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
|
|
||||||
TensorProto* output_tensor) {
|
|
||||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
|
||||||
if (elts.isSplat()) {
|
|
||||||
output_tensor->add_int_val((*elts.begin()).getSExtValue());
|
|
||||||
} else {
|
|
||||||
for (const auto& val : elts)
|
|
||||||
output_tensor->add_int_val(val.getSExtValue());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
|
|
||||||
TensorProto* output_tensor) {
|
|
||||||
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
|
|
||||||
if (!elts) {
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bfloat16 is internally represented as `double` in MLIR.
|
// Bfloat16 is internally represented as `double` in MLIR.
|
||||||
if (elts.isSplat()) {
|
if (attr.isSplat()) {
|
||||||
double v = elts.getSplatValue<double>();
|
double v = attr.getSplatValue<double>();
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||||
} else {
|
} else {
|
||||||
for (auto v : elts.getValues<double>()) {
|
for (auto v : attr.getValues<double>()) {
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||||
// with the int64_val field updated.
|
|
||||||
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
|
|
||||||
TensorProto* output_tensor) {
|
|
||||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
|
||||||
if (elts.isSplat()) {
|
|
||||||
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
|
|
||||||
} else {
|
|
||||||
for (const auto& val : elts)
|
|
||||||
output_tensor->add_int64_val(val.getSExtValue());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
|
||||||
// with bool_val field updated.
|
|
||||||
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
|
|
||||||
TensorProto* output_tensor) {
|
|
||||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
|
||||||
for (const auto& val : elts) {
|
|
||||||
output_tensor->add_bool_val(val.getBoolValue());
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ConvertToTensorProto(const ElementsAttr attr,
|
|
||||||
TensorProto* output_tensor) {
|
|
||||||
auto type = attr.getType();
|
auto type = attr.getType();
|
||||||
auto shape = type.getShape();
|
auto shape = type.getShape();
|
||||||
DataType output_dtype;
|
DataType output_dtype;
|
||||||
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
||||||
output_tensor->set_dtype(output_dtype);
|
output->set_dtype(output_dtype);
|
||||||
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
|
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
|
||||||
|
|
||||||
|
if (attr.isa<OpaqueElementsAttr>())
|
||||||
|
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
|
||||||
|
|
||||||
|
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
|
||||||
|
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
|
||||||
|
|
||||||
switch (output_dtype) {
|
switch (output_dtype) {
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
return ConvertFloatElementsAttr(attr, output_tensor);
|
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
|
||||||
|
break;
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
|
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||||
return ConvertHalfElementsAttr(attr, output_tensor);
|
output->mutable_half_val());
|
||||||
|
break;
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
return ConvertDoubleElementsAttr(attr, output_tensor);
|
ConvertElementsAttr(dense_attr, output->mutable_double_val());
|
||||||
|
break;
|
||||||
case DT_QUINT8:
|
case DT_QUINT8:
|
||||||
case DT_UINT8:
|
case DT_UINT8:
|
||||||
case DT_INT8:
|
case DT_INT8:
|
||||||
@ -366,20 +325,40 @@ Status ConvertToTensorProto(const ElementsAttr attr,
|
|||||||
case DT_UINT16:
|
case DT_UINT16:
|
||||||
case DT_INT16:
|
case DT_INT16:
|
||||||
case DT_INT32:
|
case DT_INT32:
|
||||||
return ConvertIntElementsAttr(attr, output_tensor);
|
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
|
||||||
|
output->mutable_int_val());
|
||||||
|
break;
|
||||||
|
case DT_UINT32:
|
||||||
|
ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
|
||||||
|
break;
|
||||||
|
case DT_UINT64:
|
||||||
|
ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
|
||||||
|
break;
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
return ConvertInt64ElementsAttr(attr, output_tensor);
|
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
|
||||||
|
break;
|
||||||
case DT_BOOL:
|
case DT_BOOL:
|
||||||
return ConvertBoolElementsAttr(attr, output_tensor);
|
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
|
||||||
|
break;
|
||||||
case DT_BFLOAT16:
|
case DT_BFLOAT16:
|
||||||
return ConvertBfloat16ElementsAttr(attr, output_tensor);
|
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||||
|
output->mutable_half_val());
|
||||||
|
break;
|
||||||
case DT_STRING:
|
case DT_STRING:
|
||||||
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
|
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
|
||||||
output_tensor);
|
output->mutable_string_val());
|
||||||
|
break;
|
||||||
|
case DT_COMPLEX64:
|
||||||
|
ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
|
||||||
|
break;
|
||||||
|
case DT_COMPLEX128:
|
||||||
|
ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
|
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
|
||||||
output_tensor);
|
DataTypeString(output_dtype)));
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
||||||
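The repeated splat-versus-dense branching that the old per-type helpers duplicated collapses into one template in this change. A standalone sketch of the core pattern, with the helper renamed `AppendDenseValues` to mark it as illustrative; note that tensor-proto parsing broadcasts a single stored value across the whole tensor, which is why one `Add` suffices for a splat:

    // Sketch of the pattern used by the new ConvertElementsAttr template.
    template <typename T>
    void AppendDenseValues(const mlir::DenseElementsAttr attr,
                           protobuf::RepeatedField<T>* output) {
      if (attr.isSplat()) {
        // One stored value stands for every element of the tensor.
        output->Add(attr.getSplatValue<T>());
      } else {
        for (auto value : attr.getValues<T>()) output->Add(value);
      }
    }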
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"

 #include <cstring>
+#include <initializer_list>

 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project

@@ -99,48 +100,74 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
   EXPECT_EQ(string_values[3], mlir::StringRef("four"));
 }

-TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
-  RegisterDialects();
-  mlir::MLIRContext context;
-  mlir::Builder b(&context);
-
-  {
-    // Create the sample tensor to convert.
-    Tensor tensor(DT_HALF, TensorShape({1}));
-    auto Tt = tensor.flat<Eigen::half>();
-    Tt.setValues({Eigen::half(1.0)});
-
-    auto value_or = ConvertTensor(tensor, &b);
-    TF_EXPECT_OK(value_or.status());
-    auto attr = value_or.ValueOrDie();
-
-    EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
-    EXPECT_TRUE(attr.getType().getElementType().isF16());
-
-    Tensor out;
-    TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
-    test::ExpectTensorEqual<Eigen::half>(tensor, out);
-  }
-
-  {
-    // Create the sample tensor to convert.
-    Tensor tensor(DT_BFLOAT16, TensorShape({2}));
-    auto Tt = tensor.flat<bfloat16>();
-    Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
-
-    auto value_or = ConvertTensor(tensor, &b);
-    TF_EXPECT_OK(value_or.status());
-    auto attr = value_or.ValueOrDie();
-
-    EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
-    EXPECT_TRUE(attr.getType().getElementType().isBF16());
-
-    Tensor out;
-    TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
-    test::ExpectTensorEqual<bfloat16>(tensor, out);
-  }
+class ConvertTensorTest : public ::testing::Test {
+ protected:
+  template <typename T>
+  void VerifyConversion(std::initializer_list<T> values, DataType dtype,
+                        mlir::Type expected_ty) {
+    mlir::Builder b(expected_ty.getContext());
+    Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
+    tensor.flat<T>().setValues(values);
+
+    auto value_or = ConvertTensor(tensor, &b);
+    TF_ASSERT_OK(value_or.status());
+    auto attr = value_or.ValueOrDie();
+
+    EXPECT_EQ(attr.getType().getElementType(), expected_ty);
+
+    Tensor out;
+    TF_ASSERT_OK(ConvertToTensor(attr, &out));
+
+    test::ExpectTensorEqual<T>(tensor, out);
+  }
+};
+
+TEST_F(ConvertTensorTest, Simple) {
+  RegisterDialects();
+
+  mlir::MLIRContext context;
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
+      {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
+  ASSERT_NO_FATAL_FAILURE(
+      VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
+                                 mlir::FloatType::getBF16(&context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
+      {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
+      {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
+
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
+      {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
+      {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
+      {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
+      {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
+
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
+      {1, 2}, DT_UINT8,
+      mlir::IntegerType::get(
+          8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
+      {1, 2}, DT_UINT16,
+      mlir::IntegerType::get(
+          16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
+      {1, 2}, DT_UINT32,
+      mlir::IntegerType::get(
+          32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
+      {1, 2}, DT_UINT64,
+      mlir::IntegerType::get(
+          64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
+      {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
+      mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
+      {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
+      mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
 }

 }  // namespace
@@ -59,6 +59,18 @@ limitations under the License.

 namespace tensorflow {
 namespace {
+// static TensorFlow op prefix set.
+std::set<std::string>* GlobalOpPrefixes() {
+  static std::set<std::string>* global_op_prefixes = [] {
+    std::set<std::string>* result = new std::set<std::string>;
+    result->insert("tf.");
+    result->insert("_tf.");
+    result->insert("tf_executor.");
+    return result;
+  }();
+  return global_op_prefixes;
+}
+
 // Converts a location to the debug information for the node def.
 Status ConvertLocation(mlir::Location inst_loc,
                        NodeDef::ExperimentalDebugInfo* debug_info) {

@@ -268,8 +280,10 @@ StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
   //   don't need to consider ".source"/".Source" because the nodes with this
   //   suffix are skipped by the caller and will not be added to the graph.
-  if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
-      !op_name.consume_front("tf_executor.")) {
+  auto prefixes = GlobalOpPrefixes();
+  if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
+        return op_name.consume_front(prefix);
+      })) {
     return errors::FailedPrecondition("op node '", op_name.str(),
                                       "' was not a TF op!");
   }

@@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) {
          inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
 }

+Status AddTensorFlowOpPrefix(std::string prefix) {
+  GlobalOpPrefixes()->insert(prefix);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
@@ -34,10 +34,17 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"

+namespace mlir {
+class ShapedType;
+}  // namespace mlir
+
 namespace tensorflow {

 using stream_executor::port::StatusOr;

+// Add custom op prefix for TensorFlow dialects.
+Status AddTensorFlowOpPrefix(std::string);
+
 // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
 // dialect back into a TensorFlow valid op name.
 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);
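A hypothetical use of the new hook: registering an extra dialect prefix so the exporter accepts ops outside the built-in `tf.`, `_tf.`, and `tf_executor.` set. The prefix string and the header path are invented for illustration:

    // Header path assumed; adjust to wherever these declarations live.
    #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

    void RegisterCustomPrefix() {
      // "tf_mydialect." is an invented example prefix.
      tensorflow::AddTensorFlowOpPrefix("tf_mydialect.");
      // GetTensorFlowOpName("tf_mydialect.MyOp") would now strip the prefix
      // and return "MyOp" instead of failing the precondition check.
    }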
@@ -164,12 +164,19 @@ std::string GetTPUCompilationDevice(Device system_device) {
   return DeviceNameUtils::ParsedNameToString(system_device);
 }

+// Finds the host CPU device for a given TPU device.
+std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
+  tpu_device.type = DEVICE_CPU;
+  tpu_device.id = 0;
+  return DeviceNameUtils::ParsedNameToString(tpu_device);
+}
+
 // Determines execution devices when topology and device assignment are not
 // defined. This is a special case where a single core computation is replicated
 // to every core in the mesh. TPU devices are simply added to
 // `execution_devices` of one replica. `num_replicas` must be 1 or the total
 // number of TPU devices available, and `num_cores_per_replica` must be 1.
-StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
+StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
   const int num_tasks = tpu_devices.size();

@@ -185,17 +192,18 @@ StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
         "'num_cores_per_replica' must be equal to 1, got ",
         num_cores_per_replica);

-  ExecutionDevices execution_devices;
-  execution_devices.reserve(num_replicas);
+  TPUDevicesAndHosts devices_and_hosts;
+  devices_and_hosts.reserve(num_replicas);
   for (int i = 0; i < num_replicas; ++i) {
     const int task = i / num_tpus_per_task;
     const int device = i % num_tpus_per_task;
-    execution_devices.push_back(
-        {tensorflow::DeviceNameUtils::ParsedNameToString(
-            tpu_devices[task][device])});
+    const auto& tpu_device = tpu_devices[task][device];
+    devices_and_hosts.push_back({TPUDeviceAndHost(
+        /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
+        /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
   }

-  return execution_devices;
+  return devices_and_hosts;
 }

 // Helper struct for keeping track of task and device for an associated TPU

@@ -326,7 +334,7 @@ StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
 // - number of device coordinates (in tuple 3) match number 'num_replicas' *
 //   'num_cores_per_replica'
 // - a TPU device associated with each device coordinate
-StatusOr<std::pair<ExecutionDevices, xla::DeviceAssignmentProto>>
+StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
 GetGeneralTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,

@@ -361,9 +369,9 @@ GetGeneralTPUExecutionDeviceAssignment(
   std::vector<bool> used_device_ids(
       location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1),
       false);
-  ExecutionDevices execution_devices(
-      num_replicas,
-      llvm::SmallVector<std::string, 8>(num_cores_per_replica, ""));
+  TPUDevicesAndHosts devices_and_hosts(
+      num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
+                        num_cores_per_replica, TPUDeviceAndHost()));
   xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
   int pos = 0;
   for (int replica = 0; replica < num_replicas; ++replica) {

@@ -393,16 +401,18 @@ GetGeneralTPUExecutionDeviceAssignment(

       used_device_ids[device_id] = true;
       device_assignment(replica, logical_core) = device_id;
-      execution_devices[replica][logical_core] =
-          DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]);
+      auto& device_and_host = devices_and_hosts[replica][logical_core];
+      const auto& tpu_device = tpu_devices[task][device];
+      device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
+      device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
     }
   }

   xla::DeviceAssignmentProto device_assignment_proto;
   TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));

-  return std::pair<ExecutionDevices, xla::DeviceAssignmentProto>(
-      std::move(execution_devices), std::move(device_assignment_proto));
+  return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
+      std::move(devices_and_hosts), std::move(device_assignment_proto));
 }

 }  // anonymous namespace
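The host lookup added above only rewrites two fields of the parsed device name. A sketch of the resulting mapping, using `DeviceNameUtils` the same way the diff does (`HostForTPU` is an invented wrapper):

    #include <string>
    #include "tensorflow/core/util/device_name_utils.h"

    std::string HostForTPU(const std::string& tpu_name) {
      tensorflow::DeviceNameUtils::ParsedName device;
      tensorflow::DeviceNameUtils::ParseFullName(tpu_name, &device);
      device.type = "CPU";  // mirrors tpu_device.type = DEVICE_CPU
      device.id = 0;        // mirrors tpu_device.id = 0
      // "/job:worker/replica:0/task:1/device:TPU:3"
      //   -> "/job:worker/replica:0/task:1/device:CPU:0"
      return tensorflow::DeviceNameUtils::ParsedNameToString(device);
    }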
@@ -30,29 +30,40 @@ limitations under the License.
 namespace tensorflow {
 using stream_executor::port::StatusOr;

-// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They
-// are ordered by `num_replicas` followed by `num_cores_per_replica`.
-using ExecutionDevices =
-    llvm::SmallVector<llvm::SmallVector<std::string, 8>, 8>;
+// A TPU device for execution alongside its associated host CPU device.
+struct TPUDeviceAndHost {
+  TPUDeviceAndHost() {}
+  TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
+      : device(device), host(host) {}
+
+  std::string device;
+  std::string host;
+};

-// TPU compilation device, execution devices, and optionally execution device
-// IDs. Execution device IDs are populated if `topology` and `device_assignment`
-// are provided.
+// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
+// their associated host CPU devices (for outside compilation). They are ordered
+// by `num_replicas` followed by `num_cores_per_replica`.
+using TPUDevicesAndHosts =
+    llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
+
+// TPU compilation device, execution and associated host devices, and optionally
+// execution device IDs. Execution device IDs are populated if `topology` and
+// `device_assignment` are provided.
 struct TPUDeviceAssignment {
   TPUDeviceAssignment(llvm::StringRef compilation_device,
-                      ExecutionDevices&& execution_devices)
+                      TPUDevicesAndHosts&& tpu_devices)
       : compilation_device(compilation_device),
-        execution_devices(std::move(execution_devices)) {}
+        tpu_devices(std::move(tpu_devices)) {}

   TPUDeviceAssignment(llvm::StringRef compilation_device,
-                      ExecutionDevices&& execution_devices,
+                      TPUDevicesAndHosts&& tpu_devices,
                       xla::DeviceAssignmentProto&& xla_device_assignment)
       : compilation_device(compilation_device),
-        execution_devices(std::move(execution_devices)),
+        tpu_devices(std::move(tpu_devices)),
         xla_device_assignment(std::move(xla_device_assignment)) {}

   std::string compilation_device;
-  ExecutionDevices execution_devices;
+  TPUDevicesAndHosts tpu_devices;
   llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
 };
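A small consumer sketch for the new structure; `assignment` is assumed to be a populated `TPUDeviceAssignment`, and `UseDevicePair` is an invented callback:

    // Outer index is the replica, inner index is the logical core, matching
    // the ordering documented above.
    for (const auto& replica_devices : assignment.tpu_devices) {
      for (const tensorflow::TPUDeviceAndHost& device_and_host : replica_devices) {
        UseDevicePair(device_and_host.device, device_and_host.host);
      }
    }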
@@ -323,30 +323,46 @@ TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {

   TF_ASSERT_OK(status_or.status());

-  auto& tpu_device_assignment = status_or.ValueOrDie();
+  const auto& tpu_device_assignment = status_or.ValueOrDie();
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");
-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 8);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 1);
+  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 8);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 1);

-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
-  EXPECT_EQ(execution_devices[2][0],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][0].device,
             "/job:worker/replica:0/task:0/device:TPU:2");
-  EXPECT_EQ(execution_devices[3][0],
+  EXPECT_EQ(tpu_devices[2][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][0].device,
             "/job:worker/replica:0/task:0/device:TPU:3");
-  EXPECT_EQ(execution_devices[4][0],
+  EXPECT_EQ(tpu_devices[3][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[4][0].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[5][0],
+  EXPECT_EQ(tpu_devices[4][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[5][0].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
-  EXPECT_EQ(execution_devices[6][0],
+  EXPECT_EQ(tpu_devices[5][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[6][0].device,
             "/job:worker/replica:0/task:1/device:TPU:2");
-  EXPECT_EQ(execution_devices[7][0],
+  EXPECT_EQ(tpu_devices[6][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[7][0].device,
             "/job:worker/replica:0/task:1/device:TPU:3");
+  EXPECT_EQ(tpu_devices[7][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");

   EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
 }

@@ -410,30 +426,46 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {

   TF_ASSERT_OK(status_or.status());

-  auto& tpu_device_assignment = status_or.ValueOrDie();
+  const auto& tpu_device_assignment = status_or.ValueOrDie();
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");
-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 4);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 2);
+  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 4);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 2);

-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[0][1],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][1].device,
             "/job:worker/replica:0/task:1/device:TPU:3");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
-  EXPECT_EQ(execution_devices[1][1],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][1].device,
             "/job:worker/replica:0/task:1/device:TPU:2");
-  EXPECT_EQ(execution_devices[2][0],
+  EXPECT_EQ(tpu_devices[1][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][0].device,
             "/job:worker/replica:0/task:0/device:TPU:3");
-  EXPECT_EQ(execution_devices[2][1],
+  EXPECT_EQ(tpu_devices[2][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][1].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[3][0],
+  EXPECT_EQ(tpu_devices[2][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][0].device,
             "/job:worker/replica:0/task:0/device:TPU:2");
-  EXPECT_EQ(execution_devices[3][1],
+  EXPECT_EQ(tpu_devices[3][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][1].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
+  EXPECT_EQ(tpu_devices[3][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");

   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
   ASSERT_TRUE(xla_device_assignment.hasValue());

@@ -511,23 +543,35 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");

-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 2);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 3);
+  auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 2);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 3);

-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
-  EXPECT_EQ(execution_devices[0][1],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][1].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[0][2],
+  EXPECT_EQ(tpu_devices[0][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][2].device,
             "/job:worker/replica:0/task:2/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][2].host,
+            "/job:worker/replica:0/task:2/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:2/device:TPU:1");
-  EXPECT_EQ(execution_devices[1][1],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:2/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][1].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][2],
+  EXPECT_EQ(tpu_devices[1][1].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][2].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
+  EXPECT_EQ(tpu_devices[1][2].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");

   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
   ASSERT_TRUE(xla_device_assignment.hasValue());
@@ -104,26 +104,24 @@ int main(int argc, char** argv) {
     return 1;
   }

-  if (import_saved_model_object_graph) {
-    std::unordered_set<std::string> tags =
-        absl::StrSplit(saved_model_tags, ',');
-    std::vector<std::string> exported_names =
+  std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
+  std::vector<std::string> exported_names_vector =
       absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
+  absl::Span<std::string> exported_names(exported_names_vector);
+
+  if (import_saved_model_object_graph) {
     mlir::MLIRContext context;

     auto module = tensorflow::SavedModelObjectGraphToMlirImport(
-        input_filename, tags, absl::Span<std::string>(exported_names),
-        &context);
+        input_filename, tags, exported_names, &context);
     if (!module) return 1;

     module->print(output->os());
   } else if (import_saved_model_signature_defs) {
-    std::unordered_set<std::string> tags =
-        absl::StrSplit(saved_model_tags, ',');
     mlir::MLIRContext context;

     auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
-        input_filename, tags, &context);
+        input_filename, tags, exported_names, &context);
     if (!module) return 1;

     module->print(output->os());
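One detail of the hoisted parsing worth noting: `absl::Span` does not own its elements, so the backing vector must stay alive for as long as the span is used, which is why `exported_names_vector` is a named local rather than a temporary. A minimal standalone sketch:

    #include <string>
    #include <vector>
    #include "absl/strings/str_split.h"
    #include "absl/types/span.h"

    std::vector<std::string> exported_names_vector =
        absl::StrSplit("a,b,c", ',', absl::SkipEmpty());
    // The span is a non-owning view; exported_names_vector must outlive it.
    absl::Span<std::string> exported_names(exported_names_vector);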
tensorflow/compiler/mlir/tfjs/BUILD
@@ -1,4 +1,5 @@
 load("//third_party/mlir:tblgen.bzl", "gentbl")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -39,7 +40,7 @@ gentbl(
         "ir/tfjs_ops.td",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
-        "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
+        "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
     ],
 )
@@ -131,10 +132,106 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
-        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Transforms",
     ],
 )
 
+cc_library(
+    name = "json_translate_lib",
+    srcs = [
+        "translate/json_translate.cc",
+    ],
+    hdrs = [
+        "translate/json_translate.h",
+    ],
+    deps = [
+        ":tensorflow_js",
+        ":tensorflow_js_dialect_registration",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:export_utils",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Translation",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "tf_to_tfjs_json",
+    srcs = ["translate/tf_to_tfjs_json.cc"],
+    hdrs = [
+        "translate/tf_to_tfjs_json.h",
+    ],
+    deps = [
+        ":json_translate_lib",
+        ":tfjs_optimize",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
+        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+        "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_binary(
+    name = "json_translate",
+    deps = [
+        ":json_translate_lib",
+        "@llvm-project//mlir:MlirTranslateMain",
+    ],
+)
+
+filegroup(
+    name = "tf_tfjs_translate_main",
+    srcs = [
+        "translate/tf_tfjs_translate.cc",
+    ],
+)
+
+tf_cc_binary(
+    name = "tf_tfjs_translate",
+    srcs = [":tf_tfjs_translate_main"],
+    deps = [
+        ":json_translate_lib",
+        ":tensorflow_js_passes",
+        ":tf_to_tfjs_json",
+        ":tfjs_optimize",
+        "//tensorflow/compiler/mlir:init_mlir",
+        "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffects.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 
 namespace mlir {
 namespace tfjs {
 
tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td
@@ -23,7 +23,7 @@ limitations under the License.
 #define TFJS_DIALECT
 
 include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/SideEffects.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // TensorFlow.js dialect definitions
tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD (new file, 23 lines)
@@ -0,0 +1,23 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+licenses(["notice"])
+
+glob_lit_tests(
+    data = [
+        ":test_utilities",
+    ],
+    driver = "@llvm-project//mlir:run_lit.sh",
+    test_file_exts = [
+        "pbtxt",
+    ],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt (new file, 78 lines)
@@ -0,0 +1,78 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
+# Add two int32 tensors of shape 10, then multiply the sum by itself and
+# return the product.
+
+node {
+  name: "Add"
+  op: "Add"
+  input: "input0"
+  input: "input1"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "input0"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "input1"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "Mul"
+  op: "Mul"
+  input: "Add"
+  input: "Add"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+versions {
+  producer: 27
+}
+
+# CHECK: "name": "input0"
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "input1",
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Add"
+# CHECK-NEXT: "op": "AddV2"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "input0"
+# CHECK-NEXT: "input1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul1"
+# CHECK-NEXT: "op": "Mul"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Add"
+# CHECK-NEXT: "Add"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul"
+# CHECK-NEXT: "op": "_Retval"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Mul1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "library"
+# CHECK: "versions"
+# CHECK: "producer": 27
tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt (new file, 175 lines)
@@ -0,0 +1,175 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
+# Compute relu(x) + 0.5 * relu(-x) on a float input; the converter fuses the
+# pattern into a single Prelu custom op.
+
+node {
+  name: "input0"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 10
+        }
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "alpha"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+        }
+        float_val: 0.5
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Relu"
+  op: "Relu"
+  input: "input0"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Neg"
+  op: "Neg"
+  input: "input0"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Relu1"
+  op: "Relu"
+  input: "Neg"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Mul"
+  op: "Mul"
+  input: "alpha"
+  input: "Relu1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Add"
+  op: "Add"
+  input: "Relu"
+  input: "Mul"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "main"
+  op: "_Retval"
+  input: "Add"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+}
+library {
+}
+versions {
+  producer: 344
+}
+
+# CHECK: "node":
+# CHECK: "name": "input0",
+# CHECK-NEXT: "op": "Placeholder",
+# CHECK-NEXT: "attr":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
+# CHECK-NEXT: "op": "Const",
+# CHECK-NEXT: "attr":
+# CHECK: "value":
+# CHECK: "tensor":
+# CHECK: "dtype": "DT_FLOAT",
+# CHECK: "tensorShape": {},
+# CHECK: "floatVal":
+# CHECK: -0.5
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
+# CHECK-NEXT: "op": "Prelu",
+# CHECK-NEXT: "input":
+# CHECK: "input0",
+# CHECK: "Add.Relu.Neg.Relu1.Mul"
+# CHECK: "attr":
+# CHECK: "_output_shapes":
+# CHECK: "list":
+# CHECK: "shape":
+# CHECK: "dim":
+# CHECK: "size": "10"
+# CHECK: "experimentalDebugInfo": {}
+# CHECK: "name": "Add",
+# CHECK-NEXT: "op": "_Retval",
+# CHECK-NEXT: "input":
+# CHECK: "Add.Relu.Neg.Relu1.Mul1"
+# CHECK: "attr":
+# CHECK: "T":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "library": {},
+# CHECK: "versions":
+# CHECK: "producer": 344
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
 
@@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
   // Canonicalize, CSE etc.
   pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
+
+  // raise to executor dialect in order to use GraphDef converter
+  pm->addNestedPass<mlir::FuncOp>(
+      mlir::CreateFunctionalToExecutorDialectConversionPass());
+  pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
 }
 
 }  // namespace tensorflow
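For reference, a minimal sketch of how a caller might drive this pipeline (not part of the commit); it assumes AddTFToTFJSConversionPasses is declared in tf_tfjs_passes.h with the signature shown in the hunk header, and RunTfToTfjsPipeline is a hypothetical helper name.

// Sketch only: RunTfToTfjsPipeline is a hypothetical wrapper.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"

mlir::LogicalResult RunTfToTfjsPipeline(mlir::ModuleOp module,
                                        mlir::MLIRContext* context) {
  mlir::PassManager pm(context);
  // Populates TF -> TFJS lowering, canonicalization, CSE, and the raising to
  // the executor dialect added in this hunk.
  tensorflow::AddTFToTFJSConversionPasses(&pm);
  return pm.run(module);
}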
tensorflow/compiler/mlir/tfjs/translate/json_translate.cc (new file, 105 lines)
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/status.h"
+
+using mlir::ModuleOp;
+using mlir::TranslateFromMLIRRegistration;
+using std::string;
+using tensorflow::Status;
+using xla::StatusOr;
+
+// Translates the given MLIR module in the TFJS dialect to TFJS JSON
+// format. Returns true on success, false otherwise.
+bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
+                                       std::string* serialized_json) {
+  string json_output;
+  // Allow TF to treat TFJS ops as TF ops.
+  if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
+    LOG(ERROR) << "Failed to add tfjs op prefix.";
+    return false;
+  }
+  tensorflow::GraphExportConfig confs;
+  confs.export_shapes = true;
+  confs.export_library = true;
+  tensorflow::FunctionLibraryDefinition flib_def(
+      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
+  absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
+  auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
+  auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
+                                               &control_ret_nodes);
+  if (!status.ok()) {
+    LOG(ERROR) << "Graph export failed: " << status;
+    return false;
+  }
+  auto graphdef = absl::make_unique<tensorflow::GraphDef>();
+  graph->ToGraphDef(graphdef.get());
+
+  // Replace the _Arg nodes of the main function with Placeholder op.
+  auto nodes = graphdef->mutable_node();
+  for (const auto& node : llvm::enumerate(*nodes)) {
+    if (node.value().op() == "_Arg") {
+      nodes->Mutable(node.index())->set_op("Placeholder");
+    }
+  }
+
+  tensorflow::protobuf::util::JsonPrintOptions json_options;
+  json_options.add_whitespace = true;
+  auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString(
+      *graphdef, &json_output, json_options);
+  if (!jsonStatus.ok()) {
+    LOG(ERROR) << "Proto2Json failed: " << jsonStatus;
+    return false;
+  }
+  *serialized_json = std::move(json_output);
+  return true;
+}
+
+static mlir::LogicalResult MlirToJSONFileTranslateFunction(
+    ModuleOp module, llvm::raw_ostream& output) {
+  std::string serialized_json;
+  if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
+    return mlir::failure();
+
+  output << serialized_json;
+  return mlir::success();
+}
+
+static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
+    "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);
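For reference, a hedged usage sketch of the API added above; DumpModuleAsJson is a hypothetical caller, not part of the commit.

// Sketch only: DumpModuleAsJson is a hypothetical caller.
#include <string>

#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
#include "tensorflow/core/platform/logging.h"

void DumpModuleAsJson(mlir::ModuleOp module) {
  std::string json;
  // MlirToJSONTranslateFunction returns true on success and logs details on
  // failure.
  if (!tfjs::MlirToJSONTranslateFunction(module, &json)) {
    LOG(ERROR) << "TFJS JSON translation failed";
    return;
  }
  LOG(INFO) << "TFJS JSON output is " << json.size() << " bytes";
}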
tensorflow/compiler/mlir/tfjs/translate/json_translate.h (new file, 31 lines)
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+
+#include <string>
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tfjs {
+
+// Translates the given MLIR `module` into a JSON string. Returns true on
+// success, false if the translation fails.
+bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
+                                 std::string* serialized_json);
+}  // namespace tfjs
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc (new file, 173 lines)
@@ -0,0 +1,173 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <string>
+
+#include "absl/strings/str_split.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/FileUtilities.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
+#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
+#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+using llvm::cl::opt;
+using mlir::MLIRContext;
+using stream_executor::port::StatusOr;
+
+// NOLINTNEXTLINE
+opt<std::string> input_file_name(llvm::cl::Positional,
+                                 llvm::cl::desc("<input file>"),
+                                 llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_object_graph(
+    "savedmodel-objectgraph-to-mlir",
+    llvm::cl::desc("Import a saved model to its MLIR representation"),
+    llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_signature_defs(
+    "savedmodel-signaturedefs-to-mlir",
+    llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
+    llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_tags(
+    "tf-savedmodel-tags",
+    llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
+                   "separated by ','"),
+    llvm::cl::init("serve"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_exported_names(
+    "tf-savedmodel-exported-names",
+    llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
+                   "(the default) means export all."),
+    llvm::cl::init(""));
+
+// NOLINTNEXTLINE
+opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
+                                  llvm::cl::value_desc("filename"),
+                                  llvm::cl::init("-"));
+// NOLINTNEXTLINE
+opt<bool> input_mlir(
+    "input-mlir",
+    llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
+                   "GraphDef format"),
+    llvm::cl::init(false), llvm::cl::Hidden);
+// NOLINTNEXTLINE
+opt<bool> output_mlir(
+    "output-mlir",
+    llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
+    llvm::cl::init(false));
+
+// The following approach allows injecting opdefs, in addition to those that
+// are already part of the global TF registry, to be linked in prior to
+// importing the graph. The primary goal is to support custom ops.
+// This is not intended to be a general solution for custom ops for the future
+// but mainly for supporting older models like mobilenet_ssd. More appropriate
+// mechanisms, such as op hints or using functions to represent composable ops
+// like https://github.com/tensorflow/community/pull/113 should be encouraged
+// going forward.
+// NOLINTNEXTLINE
+llvm::cl::list<std::string> custom_opdefs(
+    "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
+                                       "graphdef"));
+
+// Debugging flag to print function mapping in the JSON.
+// NOLINTNEXTLINE
+static opt<bool> print_function_result_mapping(
+    "print-function-result-mapping",
+    llvm::cl::desc(
+        "Print the mapping of function result to json output buffer"),
+    llvm::cl::init(false));
+
+enum TranslationStatus { kTrSuccess, kTrFailure };
+
+static int PrintFunctionResultMapping(const std::string& result) {
+  std::cout << result << std::endl;
+  return kTrSuccess;
+}
+
+int main(int argc, char** argv) {
+  tensorflow::InitMlir y(&argc, &argv);
+
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "TF GraphDef to TFJS JSON converter\n");
+
+  MLIRContext context;
+  llvm::SourceMgr source_mgr;
+  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
+
+  StatusOr<mlir::OwningModuleRef> module;
+
+  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
+    if (input_mlir)
+      module = tensorflow::errors::InvalidArgument(
+          "Importing saved model should not have input_mlir set");
+    else
+      module = tensorflow::ImportSavedModel(
+          import_saved_model_object_graph, import_saved_model_signature_defs,
+          custom_opdefs, input_file_name, saved_model_tags,
+          saved_model_exported_names, &context);
+  } else {
+    module = tensorflow::LoadFromGraphdefOrMlirSource(
+        input_file_name, input_mlir, custom_opdefs, debug_info_file,
+        input_arrays, input_dtypes, input_shapes, output_arrays,
+        /*prune_unused_nodes=*/true, &source_mgr, &context);
+  }
+
+  // If errors occur, the library call above has already logged the error
+  // message, so we can just return here.
+  if (!module.ok()) return kTrFailure;
+
+  mlir::PassManager pm(&context);
+
+  tensorflow::AddTFToTFJSConversionPasses(&pm);
+
+  std::string result;
+  auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
+                                                   output_mlir, &result, &pm);
+  if (!status.ok()) return kTrFailure;
+
+  std::string error_msg;
+  auto output = mlir::openOutputFile(output_file_name, &error_msg);
+  if (output == nullptr) {
+    llvm::errs() << error_msg << '\n';
+    return kTrFailure;
+  }
+  output->os() << result;
+  output->keep();
+
+  // Print out debugging info related to function mapping.
+  if (print_function_result_mapping) return PrintFunctionResultMapping(result);
+  return kTrSuccess;
+}
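For reference, a sketch of the idea behind the -tf-custom-opdefs flag described in the comment above: each list entry is an OpDef text proto that is registered before import so the GraphDef importer accepts the custom op. RegisterCustomOpDef is a hypothetical illustration; the commit performs this registration inside its import library (tf_to_tfjs_json), not with this helper.

// Sketch only: RegisterCustomOpDef is a hypothetical helper that parses one
// --tf-custom-opdefs entry (an OpDef text proto, e.g. the Prelu definition
// used by tests/e2e/prelu.pbtxt) and adds it to the global op registry.
#include <string>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"

tensorflow::Status RegisterCustomOpDef(const std::string& opdef_text) {
  tensorflow::OpDef opdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(opdef_text, &opdef)) {
    return tensorflow::errors::InvalidArgument("Invalid opdef text proto");
  }
  tensorflow::OpRegistry::Global()->Register(
      [opdef](tensorflow::OpRegistrationData* op_reg_data) {
        *op_reg_data = tensorflow::OpRegistrationData(opdef);
        return tensorflow::Status::OK();
      });
  return tensorflow::Status::OK();
}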
Some files were not shown because too many files have changed in this diff.