Merge branch 'master' into sriniva2/threadpool_build
This commit is contained in:
commit
3692d790f8
2
.bazelrc
2
.bazelrc
@ -168,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1
|
||||
build:dbg --config=opt -c dbg
|
||||
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
|
||||
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
|
||||
|
||||
build:tensorrt --action_env TF_NEED_TENSORRT=1
|
||||
|
||||
|
144
RELEASE.md
144
RELEASE.md
@ -1,3 +1,147 @@
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
|
||||
Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated.
|
||||
|
||||
## Major Features and Improvements
|
||||
|
||||
* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable.
|
||||
* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines.
|
||||
* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md).
|
||||
* `tf.distribute`:
|
||||
* Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training.
|
||||
* Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy`
|
||||
* Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this.
|
||||
* Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage.
|
||||
* Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation.
|
||||
* Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental.
|
||||
* Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks.
|
||||
* `tf.keras`:
|
||||
* `Model.fit` major improvements:
|
||||
* You can now use custom training logic with `Model.fit` by overriding `Model.train_step`.
|
||||
* Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc)
|
||||
* See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`.
|
||||
* SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of
|
||||
relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models.
|
||||
This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to
|
||||
manually call `Model._set_inputs` when using Custom Training Loops(CTLs).
|
||||
* Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator.
|
||||
This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the
|
||||
`DataAdapter` doesn't need to call the Model is probably preferable.
|
||||
* The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers)
|
||||
* Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode.
|
||||
|
||||
* `tf.lite`:
|
||||
* Enable TFLite experimental new converter by default.
|
||||
* XLA
|
||||
* XLA now builds and works on windows. All prebuilt packages come with XLA available.
|
||||
* XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction
|
||||
) with “compile or throw exception” semantics on CPU and GPU.
|
||||
|
||||
## Breaking Changes
|
||||
* `tf.keras`:
|
||||
* In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer.
|
||||
* Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function.
|
||||
* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`.
|
||||
* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release.
|
||||
* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`.
|
||||
* Keras compile/fit behavior for functional and subclassed models have been unified. Model properties such as `metrics`, `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses.`loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed.
|
||||
|
||||
## Known Caveats
|
||||
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* `tf.data`:
|
||||
* Removed `autotune_algorithm` from experimental optimization options.
|
||||
* TF Core:
|
||||
* `tf.constant` always creates CPU tensors irrespective of the current device context.
|
||||
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
|
||||
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
|
||||
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
|
||||
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
|
||||
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
|
||||
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
|
||||
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
|
||||
* Improve error message when attempting to use `None` in data-dependent control flow.
|
||||
* Add `RaggedTensor.numpy()`.
|
||||
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
|
||||
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
|
||||
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
|
||||
* Allow `batch_dims==rank(indices)` in `tf.gather`.
|
||||
* Add support for bfloat16 in `tf.print`.
|
||||
* `tf.distribute`:
|
||||
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
|
||||
* `tf.keras`:
|
||||
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
|
||||
* Allow `pathlib.Path` paths for loading models via Keras API.
|
||||
* `tf.function`/AutoGraph:
|
||||
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
|
||||
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
|
||||
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
|
||||
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
|
||||
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
|
||||
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
|
||||
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
|
||||
* `tf.lite`:
|
||||
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
|
||||
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
|
||||
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
|
||||
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
|
||||
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
|
||||
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
|
||||
* TFLite's unpack op now supports boolean tensor inputs.
|
||||
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
|
||||
* Check for large TFLite tensors.
|
||||
* Fix GPU delegate crash with C++17.
|
||||
* Add 5D support to TFLite `strided_slice`.
|
||||
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
|
||||
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
|
||||
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
|
||||
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
|
||||
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
|
||||
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
|
||||
* `tf.random`:
|
||||
* Various random number generation improvements:
|
||||
* Add a fast path for default `random_uniform`
|
||||
* `random_seed` documentation improvement.
|
||||
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
|
||||
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
|
||||
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
|
||||
* Math and Linear Algebra:
|
||||
* Add `tf.linalg.LinearOperatorTridiag`.
|
||||
* Add `LinearOperatorBlockLowerTriangular`
|
||||
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
|
||||
* Add `tf.math.sobol_sample` op.
|
||||
* Add `tf.math.xlog1py`.
|
||||
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
|
||||
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
|
||||
* TPU Enhancements:
|
||||
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
|
||||
* Support configuring TPU software version from cloud tpu client.
|
||||
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
|
||||
* XLA Support:
|
||||
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
|
||||
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
|
||||
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
|
||||
* Enable `Igamma`, `Igammac` for XLA.
|
||||
* Deterministic Op Functionality:
|
||||
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
|
||||
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
|
||||
* Tracing and Debugging:
|
||||
* Add source, destination name to `_send` traceme to allow easier debugging.
|
||||
* Add traceme event to `fastpathexecute`.
|
||||
* Other:
|
||||
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
|
||||
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
|
||||
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi
|
||||
|
||||
# Release 2.0.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
36
configure.py
36
configure.py
@ -144,7 +144,7 @@ def write_to_bazelrc(line):
|
||||
|
||||
|
||||
def write_action_env_to_bazelrc(var_name, var):
|
||||
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
|
||||
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
|
||||
|
||||
|
||||
def run_shell(cmd, allow_non_zero=False, stderr=None):
|
||||
@ -205,7 +205,7 @@ def setup_python(environ_cp):
|
||||
# Get PYTHON_BIN_PATH, default is the current running python.
|
||||
default_python_bin_path = sys.executable
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
'{}]: ').format(default_python_bin_path)
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
@ -215,9 +215,10 @@ def setup_python(environ_cp):
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
elif not os.path.exists(python_bin_path):
|
||||
print('Invalid python path: %s cannot be found.' % python_bin_path)
|
||||
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
|
||||
else:
|
||||
print('%s is not executable. Is it the python binary?' % python_bin_path)
|
||||
print('{} is not executable. Is it the python binary?'.format(
|
||||
python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = ''
|
||||
|
||||
# Convert python path to Windows style before checking lib and version
|
||||
@ -236,7 +237,7 @@ def setup_python(environ_cp):
|
||||
default_python_lib_path = python_lib_paths[0]
|
||||
python_lib_path = get_input(
|
||||
'Please input the desired Python library path to use. '
|
||||
'Default is [%s]\n' % python_lib_paths[0])
|
||||
'Default is [{}]\n'.format(python_lib_paths[0]))
|
||||
if not python_lib_path:
|
||||
python_lib_path = default_python_lib_path
|
||||
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
|
||||
@ -252,7 +253,7 @@ def setup_python(environ_cp):
|
||||
# Set-up env variables used by python_configure.bzl
|
||||
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
|
||||
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
|
||||
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
|
||||
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
|
||||
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
||||
|
||||
# If choosen python_lib_path is from a path specified in the PYTHONPATH
|
||||
@ -266,7 +267,7 @@ def setup_python(environ_cp):
|
||||
with open(
|
||||
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
|
||||
'w') as f:
|
||||
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
||||
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
|
||||
|
||||
|
||||
def reset_tf_configure_bazelrc():
|
||||
@ -320,11 +321,12 @@ def get_var(environ_cp,
|
||||
Raise the error to avoid infinitely looping.
|
||||
"""
|
||||
if not question:
|
||||
question = 'Do you wish to build TensorFlow with %s support?' % query_item
|
||||
question = 'Do you wish to build TensorFlow with {} support?'.format(
|
||||
query_item)
|
||||
if not yes_reply:
|
||||
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
|
||||
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
|
||||
if not no_reply:
|
||||
no_reply = 'No %s' % yes_reply
|
||||
no_reply = 'No {}'.format(yes_reply)
|
||||
|
||||
yes_reply += '\n'
|
||||
no_reply += '\n'
|
||||
@ -368,7 +370,7 @@ def get_var(environ_cp,
|
||||
print(no_reply)
|
||||
var = False
|
||||
else:
|
||||
print('Invalid selection: %s' % user_input_origin)
|
||||
print('Invalid selection: {}'.format(user_input_origin))
|
||||
return var
|
||||
|
||||
|
||||
@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version):
|
||||
if which('bazel') is None:
|
||||
print('Cannot find bazel. Please install bazel.')
|
||||
sys.exit(0)
|
||||
curr_version = run_shell(
|
||||
['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
|
||||
|
||||
for line in curr_version.split('\n'):
|
||||
if 'Build label: ' in line:
|
||||
curr_version = line.split('Build label: ')[1]
|
||||
break
|
||||
stderr = open(os.devnull, 'wb')
|
||||
curr_version = run_shell(['bazel', '--version'],
|
||||
allow_non_zero = True,
|
||||
stderr = stderr)
|
||||
if curr_version.startswith('bazel '):
|
||||
curr_version = curr_version.split('bazel ')[1]
|
||||
|
||||
min_version_int = convert_version_to_int(min_version)
|
||||
curr_version_int = convert_version_to_int(curr_version)
|
||||
|
@ -517,6 +517,7 @@ package_group(
|
||||
"//perftools/accelerators/xprof/api/...",
|
||||
"//third_party/py/autograph/...",
|
||||
"//third_party/swift/tensorflow/x10/...",
|
||||
"//third_party/swift/tensorflow_apis/...",
|
||||
"//tensorflow/...",
|
||||
"//tensorflow_estimator/python/estimator/...",
|
||||
"//tensorflow_models/official/...",
|
||||
@ -529,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
# Packages that use private types symbols, until they are exported.
|
||||
# TODO(b/154650521) Remove.
|
||||
package_group(
|
||||
name = "types_whitelist",
|
||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
|
@ -16,7 +16,6 @@ load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -609,7 +608,6 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*c_api_tfrt*",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#include "tensorflow/c/eager/c_api_tfrt.h"
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -924,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||
context->GetDevicePlacementPolicy());
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||
tensorflow::Tensor tensor;
|
||||
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
// placed in memory of different devices or remote address spaces.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
|
||||
TF_Status* status);
|
||||
// Indicates that the caller will not be using `h` any more.
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
|
||||
|
@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->at(task_index) =
|
||||
tensorflow::strings::StrCat("localhost:", port);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Read the value of variable `var` and save it into `out_value`.
|
||||
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
|
||||
TFE_TensorHandle** out_value) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, out_value, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
|
||||
TFE_TensorHandle* value_handle = nullptr;
|
||||
ReadVariable(ctx, var_handle1, &value_handle);
|
||||
CheckTFE_TensorHandleHasFloats(value_handle, {2});
|
||||
TFE_DeleteTensorHandle(value_handle);
|
||||
|
||||
// Start a new worker to replace task:1
|
||||
ReplaceTaskInServerDef(&server_def, 1);
|
||||
server_def.set_task_index(1);
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
// Update server def to replace the remote device with the device info on the
|
||||
// new worker (different incarnation ID).
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// The device of var_handle0 is local device which is the same before and
|
||||
// after cluster update. Remove resource with valid device should succeed.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle0, status);
|
||||
TFE_OpSetDevice(op, dev0_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
// The device of var_handle1 is remote device, which was replaced during
|
||||
// cluster update. Removing resource with invalid device should fail
|
||||
// gracefully (i.e., with error status) instead of crashing with segfaults.
|
||||
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(op, var_handle1, status);
|
||||
TFE_OpSetDevice(op, dev1_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(op);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
|
||||
TestRemoteExecuteUpdateServerDefResourceAccess(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
server_def.set_task_index(0);
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
||||
TFE_TensorHandle** handles,
|
||||
int* num_handles,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::TensorHandle*> tensor_handles;
|
||||
tensor_handles.reserve(*num_handles);
|
||||
for (int i = 0; i < *num_handles; ++i) {
|
||||
tensor_handles.push_back(
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
status->status = tensorflow::TensorHandle::CreatePackedHandle(
|
||||
std::move(tensor_handles), context, &handle);
|
||||
return tensorflow::wrap(handle);
|
||||
}
|
||||
|
||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetAllowSoftPlacement(enable);
|
||||
}
|
||||
|
||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetLogDevicePlacement(enable);
|
||||
}
|
||||
|
@ -541,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
// Create a packed TensorHandle with the given list of TensorHandles.
|
||||
// If `handles` are on the same device, assign the same device to the packed
|
||||
// handle; if `handles` are on different deivces, assign a CompositeDevice to
|
||||
// it.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
|
||||
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure soft device placement policy for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
// Configure device placement policy logging for the eager executor. Note this
|
||||
// policy is applied to any subsequent op executions.
|
||||
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
|
||||
unsigned char enable,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -351,6 +351,192 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
// Add the values of three variables on three different tasks.
|
||||
string AddVariablesFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'AddVariablesFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'sum'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read1'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read2'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var'"
|
||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add1'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read1:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add2'"
|
||||
" op: 'Add'"
|
||||
" input: 'add1:z:0'"
|
||||
" input: 'read2:value:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'sum'"
|
||||
" value: 'add2:z:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
TEST(CAPI, TestFunctionWithPackedInput) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
// Create one variable per task.
|
||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||
|
||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||
int num_replicas = 3;
|
||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||
TFE_TensorHandle* packed_handle =
|
||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||
|
||||
const string composite_device_name =
|
||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||
composite_device_name);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// Register and run a function which returns the sum of 3 variables.
|
||||
const string function_def = AddVariablesFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, packed_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(packed_handle);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(sum, 6.0);
|
||||
|
||||
TFE_DeleteTensorHandle(h0);
|
||||
TFE_DeleteTensorHandle(h1);
|
||||
TFE_DeleteTensorHandle(h2);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
|
@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
}
|
||||
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
|
||||
|
||||
TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
|
||||
TF_Status* status) {
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TEST(CAPI, Variables) {
|
||||
// Variables use resource handles, so this is really a test for resource
|
||||
// tensor handling.
|
||||
@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
|
||||
@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) {
|
||||
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
h = nullptr;
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
tensorflow::testing::StopTiming();
|
||||
TFE_DeleteOp(op);
|
||||
|
@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
// Create the variable handle.
|
||||
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
|
||||
TFE_OpSetAttrString(op, "container", "", 0);
|
||||
TFE_OpSetAttrString(op, "shared_name", "", 0);
|
||||
if (!device_name.empty()) {
|
||||
TFE_OpSetDevice(op, device_name.c_str(), status);
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_TensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &var_handle, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(1, num_retvals);
|
||||
|
||||
// Assign 'value' to it.
|
||||
op = TFE_NewOp(ctx, "AssignVariableOp", status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op, var_handle, status);
|
||||
|
||||
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
|
||||
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
|
||||
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
value_handle(TFE_NewTensorHandle(t.get(), status),
|
||||
TFE_DeleteTensorHandle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, value_handle.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
|
||||
num_retvals = 0;
|
||||
TFE_Execute(op, nullptr, &num_retvals, status);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
CHECK_EQ(0, num_retvals);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
return var_handle;
|
||||
}
|
||||
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
|
@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a variable handle referring to a variable with the given initial value
|
||||
// on the given device.
|
||||
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
|
||||
const tensorflow::string& device_name = "");
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
|
||||
|
@ -29,7 +29,7 @@ using tensorflow::string;
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifedCAPI, TestBasicEager) {
|
||||
TEST(UnifiedCAPI, TestBasicEager) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -81,7 +81,7 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
|
||||
TEST(UnifiedCAPI, TestBasicGraph) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
@ -131,6 +131,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
string fn_name = "double";
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractTensor(output_t);
|
||||
|
||||
@ -184,7 +185,7 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
@ -200,7 +201,7 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
@ -221,7 +222,7 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
@ -242,7 +243,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -288,7 +289,7 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
|
@ -59,6 +59,20 @@ class AbstractContextInterface {
|
||||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
|
||||
|
||||
// Create a tensor instance from the given data buffer and description.
|
||||
// `memory_releaser` will be called on destruction, and it's responsible for
|
||||
// cleaning up the underlying buffer. `convert_string` indicates whether it
|
||||
// has to handle tstring conversion. Expected to be removed once tstring
|
||||
// migration is done.
|
||||
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims, void* data,
|
||||
size_t len, bool convert_string,
|
||||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
|
@ -27,6 +27,7 @@ cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -43,6 +44,7 @@ tf_cc_test(
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
":parallel_device_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -52,3 +54,19 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: ParallelDevice-specific ops are experimental and not currently linked in
|
||||
# to TensorFlow by default, just used in a few tests.
|
||||
filegroup(
|
||||
name = "parallel_device_ops_srcs",
|
||||
srcs = ["parallel_device_ops.cc"],
|
||||
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device_ops",
|
||||
srcs = [":parallel_device_ops_srcs"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -92,6 +92,10 @@ class ParallelDevice {
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) const;
|
||||
|
||||
// A parallel tensor with scalar integers numbering component devices.
|
||||
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
|
||||
TF_Status* status) const;
|
||||
|
||||
// Takes a description of a single operation being executed on the
|
||||
// ParallelDevice, and in turn runs one operation per component device with
|
||||
// its corresponding inputs from the input ParallelTensors (or
|
||||
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
|
||||
status);
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
|
||||
TFE_Context* context, TF_Status* status) const {
|
||||
// TODO(allenl): We could cache DeviceIDs (keyed by context).
|
||||
std::vector<TensorHandlePtr> components;
|
||||
components.reserve(underlying_devices_.size());
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
++device_index) {
|
||||
int64_t* device_id = new int64_t;
|
||||
*device_id = device_index;
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(
|
||||
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
|
||||
sizeof(int64_t),
|
||||
[](void* data, size_t, void* arg) {
|
||||
delete reinterpret_cast<int64_t*>(data);
|
||||
},
|
||||
nullptr),
|
||||
TF_DeleteTensor);
|
||||
// TODO(allenl): Here and when executing regular operations, we could hold
|
||||
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
|
||||
// device names repeatedly.
|
||||
OpPtr const_op(TFE_NewOp(context, "Const", status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
|
||||
status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
|
||||
TFE_TensorHandle* device_handle;
|
||||
int num_outputs = 1;
|
||||
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
components.emplace_back(device_handle);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
return ParallelTensor::FromTensorHandles(*this, std::move(components),
|
||||
status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
|
||||
}
|
||||
result.emplace(std::move(outputs));
|
||||
return result;
|
||||
} else if (operation_name == std::string("DeviceID")) {
|
||||
std::vector<MaybeParallelTensorOwned> result_content;
|
||||
result_content.reserve(1);
|
||||
result_content.push_back(DeviceIDs(context, status));
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
result.emplace(std::move(result_content));
|
||||
return result;
|
||||
}
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
maybe_parallel_results(
|
||||
|
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
26
tensorflow/c/eager/parallel_device/parallel_device_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
// TODO(allenl): Figure out if we need this op, and if so whether we should move
|
||||
// it to core TF. Right now the eager C API does some checking of op
|
||||
// registrations before calling into custom devices, but we may be able to avoid
|
||||
// that.
|
||||
REGISTER_OP("DeviceID")
|
||||
.Output("device_id: int64")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
|
@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
|
||||
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
|
||||
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
|
||||
template <typename value_type>
|
||||
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
|
||||
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_EQ(expected_value,
|
||||
*static_cast<float*>(TF_TensorData(value_zero.get())));
|
||||
EXPECT_EQ(expected_value,
|
||||
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
|
||||
}
|
||||
|
||||
template <std::size_t num_devices>
|
||||
@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 20.);
|
||||
AssertScalarFloatEq(components[1].get(), 20.);
|
||||
ExpectScalarEq<float>(components[0].get(), 20.);
|
||||
ExpectScalarEq<float>(components[1].get(), 20.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
ExtractPerDeviceValues(context, read.get(), &components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(components[0].get(), 23.);
|
||||
AssertScalarFloatEq(components[1].get(), 18.);
|
||||
ExpectScalarEq<float>(components[0].get(), 23.);
|
||||
ExpectScalarEq<float>(components[1].get(), 18.);
|
||||
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
// Compute the device ID twice and verify the result
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TensorHandlePtr, 2> components;
|
||||
ExtractPerDeviceValues(context, result_handle, &components, status.get());
|
||||
TFE_DeleteTensorHandle(result_handle);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
ExpectScalarEq<int64_t>(components[0].get(), 0);
|
||||
ExpectScalarEq<int64_t>(components[1].get(), 1);
|
||||
std::string first_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device =
|
||||
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestBasicCPU) {
|
||||
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// The value of the original tensor is replicated on each device.
|
||||
AssertScalarFloatEq(components[0].get(), 3.);
|
||||
AssertScalarFloatEq(components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(components[1].get(), 3.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device =
|
||||
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
&second_components, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
AssertScalarFloatEq(second_components[1].get(), 9.);
|
||||
ExpectScalarEq<float>(second_components[1].get(), 9.);
|
||||
|
||||
// Verify that the mirrors are placed on the component devices.
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
|
||||
std::array<TensorHandlePtr, 2> first_components;
|
||||
ExtractPerDeviceValues(context.get(), second_components[0].get(),
|
||||
&first_components, status.get());
|
||||
AssertScalarFloatEq(first_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(first_components[1].get(), 6.);
|
||||
ExpectScalarEq<float>(first_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(first_components[1].get(), 6.);
|
||||
|
||||
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
|
||||
status.get());
|
||||
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
|
@ -178,7 +178,7 @@ cc_library_with_android_deps(
|
||||
name = "ops",
|
||||
srcs = ["framework/ops.cc"],
|
||||
hdrs = ["framework/ops.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -197,7 +197,7 @@ cc_library_with_android_deps(
|
||||
"framework/scope_internal.h",
|
||||
],
|
||||
hdrs = ["framework/scope.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
],
|
||||
@ -237,7 +237,7 @@ cc_library_with_android_deps(
|
||||
name = "client_session",
|
||||
srcs = ["client/client_session.cc"],
|
||||
hdrs = ["client/client_session.h"],
|
||||
android_deps = ["//tensorflow/core:android_tensorflow_lib"],
|
||||
android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
|
||||
common_deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
@ -275,7 +275,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/const_op.cc"],
|
||||
hdrs = ["ops/const_op.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":ops",
|
||||
@ -304,7 +304,7 @@ cc_library_with_android_deps(
|
||||
srcs = ["ops/while_loop.cc"],
|
||||
hdrs = ["ops/while_loop.h"],
|
||||
android_deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:portable_tensorflow_lib",
|
||||
],
|
||||
common_deps = [
|
||||
":cc_ops",
|
||||
|
@ -57,7 +57,22 @@ cc_library(
|
||||
"tensor.h",
|
||||
],
|
||||
deps = [
|
||||
":status",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle",
|
||||
hdrs = [
|
||||
"tensorhandle.h",
|
||||
],
|
||||
deps = [
|
||||
":runtime",
|
||||
":status",
|
||||
":tensor",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
|
||||
@ -40,6 +41,7 @@ class Runtime {
|
||||
private:
|
||||
friend class RuntimeBuilder;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TFE_Context. Takes ownership of ctx.
|
||||
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
|
||||
@ -63,6 +65,7 @@ class Runtime {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
|
||||
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Status is a wrapper around an error code and an optional error message.
|
||||
@ -57,6 +58,7 @@ class Status {
|
||||
friend class RuntimeBuilder;
|
||||
friend class Runtime;
|
||||
friend class SavedModelAPI;
|
||||
friend class TensorHandle;
|
||||
|
||||
// Wraps a TF_Status*, and takes ownership of it.
|
||||
explicit Status(TF_Status* status) : status_(status) {}
|
||||
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
|
||||
|
@ -19,30 +19,53 @@ limitations under the License.
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// Tensor represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
// TODO(bmzhao): Add a factory function that constructs a Tensor from a char
|
||||
// buffer, with an options struct (to specify the buffer's layout, device?,
|
||||
// whether to create a TFRT or TF tensor, whether we should take ownership of
|
||||
// the memory, etc). This requires extending TF_NewTensor with an options
|
||||
// struct:
|
||||
// https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80
|
||||
using DeleterCallback = std::function<void(void*, size_t)>;
|
||||
|
||||
// Constructs a Tensor from user provided buffer.
|
||||
//
|
||||
// Params:
|
||||
// dtype - The dtype of the tensor's data.
|
||||
// shape - A shape vector, where each element corresponds to the size of
|
||||
// the tensor's corresponding dimension.
|
||||
// data - Pointer to a buffer of memory to construct a Tensor out of.
|
||||
// len - The length (in bytes) of `data`
|
||||
// deleter - A std::function to be called when the Tensor no longer needs the
|
||||
// memory in `data`. This can be used to free `data`, or
|
||||
// perhaps decrement a refcount associated with `data`, etc.
|
||||
// status - Set to OK on success and an error on failure.
|
||||
// Returns:
|
||||
// If an error occurred, status->ok() will be false, and the returned
|
||||
// Tensor must not be used.
|
||||
// TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
|
||||
// a TFRT backed tensor.
|
||||
// TODO(bmzhao): Add benchmarks on overhead for this function; we can
|
||||
// consider using int64_t* + length rather than vector.
|
||||
static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
|
||||
void* data, size_t len, DeleterCallback deleter,
|
||||
Status* status);
|
||||
|
||||
// TODO(bmzhao): In the case we construct a tensor from non-owned memory,
|
||||
// we should offer a way to deep copy the tensor into a new tensor, which
|
||||
// owns the underlying memory. This could be a .deepcopy()/clone() method.
|
||||
|
||||
// TODO(bmzhao): In the future, we want to relax the non-copyability
|
||||
// constraint. To do so, we can add a C API function that acts like CopyFrom:
|
||||
// constraint. To do so, we can add a C API function that acts like
|
||||
// CopyFrom:
|
||||
// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
|
||||
|
||||
// Tensor is movable, but not copyable
|
||||
@ -85,6 +108,16 @@ class Tensor {
|
||||
// This object retains ownership of the pointer.
|
||||
TF_Tensor* GetTFTensor() const { return tensor_.get(); }
|
||||
|
||||
struct DeleterStruct {
|
||||
std::function<void(void*, size_t)> deleter;
|
||||
};
|
||||
|
||||
static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
|
||||
DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
|
||||
deleter->deleter(memory, len);
|
||||
delete deleter;
|
||||
}
|
||||
|
||||
struct TFTensorDeleter {
|
||||
void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
|
||||
};
|
||||
@ -111,7 +144,32 @@ inline size_t Tensor::num_bytes() const {
|
||||
return TF_TensorByteSize(tensor_.get());
|
||||
}
|
||||
|
||||
inline Tensor Tensor::FromBuffer(TF_DataType dtype,
|
||||
const std::vector<int64_t>& shape, void* data,
|
||||
size_t len, DeleterCallback deleter,
|
||||
Status* status) {
|
||||
// Credit to apassos@ for this technique:
|
||||
// Despite the fact that our API takes a std::function deleter, we are able
|
||||
// to maintain ABI stability because:
|
||||
// 1. Only a function pointer is sent across the C API (&DeleterFunction)
|
||||
// 2. DeleterFunction is defined in the same build artifact that constructed
|
||||
// the std::function (so there isn't confusion about std::function ABI).
|
||||
// Note that 2. is satisifed by the fact that this is a header-only API, where
|
||||
// the function implementations are inline.
|
||||
|
||||
DeleterStruct* deleter_struct = new DeleterStruct{deleter};
|
||||
TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
|
||||
&DeleterFunction, deleter_struct);
|
||||
if (tensor == nullptr) {
|
||||
status->SetStatus(TF_INVALID_ARGUMENT,
|
||||
"Failed to create tensor for input buffer");
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
|
||||
|
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
98
tensorflow/cc/experimental/base/public/tensorhandle.h
Normal file
@ -0,0 +1,98 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/status.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// An opaque representation of a tensor computed/managed by the Tensorflow
|
||||
// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer
|
||||
// to tensors placed in memory of different devices or remote address spaces.
|
||||
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
|
||||
// from it.
|
||||
class TensorHandle {
|
||||
public:
|
||||
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
|
||||
// status->ok() will be false, and the returned Tensor must not be used.
|
||||
Tensor Resolve(Status* status);
|
||||
|
||||
// Constructs a TensorHandle from a Tensor. If an error occurred,
|
||||
// status->ok() will be false, and the returned TensorHandle must not be used.
|
||||
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
|
||||
Status* status);
|
||||
|
||||
// TensorHandle is movable, and not copyable
|
||||
TensorHandle(TensorHandle&&) = default;
|
||||
TensorHandle& operator=(TensorHandle&&) = default;
|
||||
|
||||
private:
|
||||
// Wraps a TFE_TensorHandle. Takes ownership of handle.
|
||||
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
|
||||
|
||||
// TensorHandle is not copyable
|
||||
TensorHandle(const TensorHandle&) = delete;
|
||||
TensorHandle& operator=(const TensorHandle&) = delete;
|
||||
|
||||
// Returns the underlying TFE_TensorHandle that this object wraps.
|
||||
// This object retains ownership of the pointer.
|
||||
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
|
||||
|
||||
// Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle,
|
||||
// and takes ownership of handle.
|
||||
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
|
||||
|
||||
struct TFETensorHandleDeleter {
|
||||
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
|
||||
};
|
||||
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
|
||||
};
|
||||
|
||||
inline Tensor TensorHandle::Resolve(Status* status) {
|
||||
TF_Tensor* tensor =
|
||||
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return Tensor(nullptr);
|
||||
}
|
||||
return Tensor(tensor);
|
||||
}
|
||||
|
||||
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
|
||||
const Runtime& runtime,
|
||||
Status* status) {
|
||||
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
|
||||
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
|
||||
if (!status->ok()) {
|
||||
return TensorHandle(nullptr);
|
||||
}
|
||||
return TensorHandle(tensor_handle);
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
|
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
50
tensorflow/cc/experimental/base/tests/BUILD
Normal file
@ -0,0 +1,50 @@
|
||||
# Tests for the C++ header-only base types.
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_types_test_util",
|
||||
testonly = True,
|
||||
hdrs = ["tensor_types_test_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_datatype",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensor_test",
|
||||
srcs = [
|
||||
"tensor_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tensorhandle_test",
|
||||
srcs = [
|
||||
"tensorhandle_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":tensor_types_test_util",
|
||||
"//tensorflow/c:tf_datatype",
|
||||
"//tensorflow/cc/experimental/base/public:runtime",
|
||||
"//tensorflow/cc/experimental/base/public:runtime_builder",
|
||||
"//tensorflow/cc/experimental/base/public:status",
|
||||
"//tensorflow/cc/experimental/base/public:tensor",
|
||||
"//tensorflow/cc/experimental/base/public:tensorhandle",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
163
tensorflow/cc/experimental/base/tests/tensor_test.cc
Normal file
@ -0,0 +1,163 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
|
||||
bool done = false;
|
||||
Status status;
|
||||
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
|
||||
{
|
||||
// data_vector is a rank 1 tensor.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(data_vector.size());
|
||||
|
||||
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
|
||||
done = true;
|
||||
};
|
||||
|
||||
Tensor tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
|
||||
/*data=*/data_vector.data(),
|
||||
/*len=*/data_vector.size() * sizeof(int32_t),
|
||||
/*deleter=*/callback, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
}
|
||||
// At this point, tensor has been destroyed, and the deleter callback should
|
||||
// have run.
|
||||
EXPECT_TRUE(done);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Each of the following struct types have two members: a kDType that
|
||||
// corresponds to a TF_Datatype enum value, and a typedef "type"
|
||||
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
|
||||
// tests via GoogleTest's TypedTests:
|
||||
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
|
||||
struct FloatType {
|
||||
using type = float;
|
||||
static constexpr TF_DataType kDType = TF_FLOAT;
|
||||
};
|
||||
|
||||
struct DoubleType {
|
||||
using type = double;
|
||||
static constexpr TF_DataType kDType = TF_DOUBLE;
|
||||
};
|
||||
|
||||
struct Int32Type {
|
||||
using type = int32_t;
|
||||
static constexpr TF_DataType kDType = TF_INT32;
|
||||
};
|
||||
|
||||
struct UINT8Type {
|
||||
using type = uint8_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT8;
|
||||
};
|
||||
|
||||
struct INT8Type {
|
||||
using type = int8_t;
|
||||
static constexpr TF_DataType kDType = TF_INT8;
|
||||
};
|
||||
|
||||
struct INT64Type {
|
||||
using type = int64_t;
|
||||
static constexpr TF_DataType kDType = TF_INT64;
|
||||
};
|
||||
|
||||
struct UINT16Type {
|
||||
using type = uint16_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT16;
|
||||
};
|
||||
|
||||
struct UINT32Type {
|
||||
using type = uint32_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT32;
|
||||
};
|
||||
|
||||
struct UINT64Type {
|
||||
using type = uint64_t;
|
||||
static constexpr TF_DataType kDType = TF_UINT64;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
|
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
184
tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime.h"
|
||||
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
|
||||
#include "tensorflow/cc/experimental/base/public/tensor.h"
|
||||
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
using tensorflow::experimental::cc::Tensor;
|
||||
using tensorflow::experimental::cc::TensorHandle;
|
||||
|
||||
using SimpleTypes = ::testing::Types<
|
||||
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
|
||||
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
|
||||
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
|
||||
|
||||
template <typename T>
|
||||
class ConstructScalarTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
|
||||
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
|
||||
// verify the expected dims, dtype, value, num bytes, and num elements.
|
||||
TYPED_TEST(ConstructScalarTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
typename TypeParam::type value = 42;
|
||||
Tensor original_tensor =
|
||||
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
|
||||
/*data=*/&value,
|
||||
/*len=*/sizeof(value),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 0);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
|
||||
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct1DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct1DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 1 vector.
|
||||
std::vector<int64_t> shape;
|
||||
shape.push_back(value.size());
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 1);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Construct2DTensorHandleTest : public ::testing::Test {};
|
||||
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
|
||||
|
||||
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
|
||||
// and verifies the expected dimensions, dtype, value, number of bytes, and
|
||||
// number of elements.
|
||||
TYPED_TEST(Construct2DTensorHandleTest,
|
||||
ValidTensorAttributesAfterConstruction) {
|
||||
Status status;
|
||||
RuntimeBuilder runtime_builder;
|
||||
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TF_DataType dtype = TypeParam::kDType;
|
||||
// This is our 1D tensor of varying dtype.
|
||||
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
|
||||
// Shape is Rank 2 vector with shape 2 x 3.
|
||||
std::vector<int64_t> shape({2, 3});
|
||||
|
||||
Tensor original_tensor = Tensor::FromBuffer(
|
||||
/*dtype=*/dtype, /*shape=*/shape,
|
||||
/*data=*/value.data(),
|
||||
/*len=*/value.size() * sizeof(typename TypeParam::type),
|
||||
/*deleter=*/[](void*, size_t) {}, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
TensorHandle handle =
|
||||
TensorHandle::FromTensor(original_tensor, *runtime, &status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
Tensor tensor = handle.Resolve(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
EXPECT_EQ(tensor.dims(), 2);
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
|
||||
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
|
||||
EXPECT_EQ(tensor_view[0], 42);
|
||||
EXPECT_EQ(tensor_view[1], 100);
|
||||
EXPECT_EQ(tensor_view[2], 0);
|
||||
EXPECT_EQ(tensor_view[3], 1);
|
||||
EXPECT_EQ(tensor_view[4], 4);
|
||||
EXPECT_EQ(tensor_view[5], 29);
|
||||
|
||||
EXPECT_EQ(tensor.num_bytes(),
|
||||
value.size() * sizeof(typename TypeParam::type));
|
||||
EXPECT_EQ(tensor.num_elements(), value.size());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -4,7 +4,6 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"if_ios",
|
||||
"if_mobile",
|
||||
"if_not_mobile",
|
||||
"tf_cc_test",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
|
||||
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// ConcreteFunctionList helps convert an opaque pointer to an array of
|
||||
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// FunctionMetadata stores additional function information, including
|
||||
@ -40,6 +41,7 @@ class FunctionMetadata final {
|
||||
};
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace experimental {
|
||||
namespace cc {
|
||||
|
||||
// SavedModelAPI offers a way to load Tensorflow Saved Models
|
||||
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
|
||||
}
|
||||
|
||||
} // namespace cc
|
||||
} // namespace experimental
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
@ -26,10 +26,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::experimental::cc::Runtime;
|
||||
using tensorflow::experimental::cc::RuntimeBuilder;
|
||||
using tensorflow::experimental::cc::SavedModelAPI;
|
||||
using tensorflow::experimental::cc::Status;
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unordered_set<std::string> tags = {"serve"};
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
}
|
||||
|
||||
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
|
||||
cc::Status status;
|
||||
cc::RuntimeBuilder builder;
|
||||
Status status;
|
||||
RuntimeBuilder builder;
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
builder.SetUseTFRT(use_tfrt);
|
||||
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
|
||||
std::unique_ptr<Runtime> runtime = builder.Build(&status);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
std::unique_ptr<cc::SavedModelAPI> model =
|
||||
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
std::unique_ptr<SavedModelAPI> model =
|
||||
SavedModelAPI::Load(model_dir, *runtime, &status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
|
||||
std::vector<string> dim_vars;
|
||||
string dim_sizes, indices;
|
||||
int count = 1;
|
||||
if (shape.rank() == 0 ||
|
||||
(shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
|
||||
dim_sizes = "[1]";
|
||||
@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
dim_vars.push_back(absl::StrCat("size_t dim", dim));
|
||||
dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
|
||||
indices += absl::StrCat("[dim", dim, "]");
|
||||
count *= shape.dimensions(dim);
|
||||
}
|
||||
}
|
||||
rewrites->push_back({"{{I}}", absl::StrCat(i)});
|
||||
@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
|
||||
rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
|
||||
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
|
||||
rewrites->push_back({"{{INDICES}}", indices});
|
||||
rewrites->push_back({"{{COUNT}}", absl::StrCat(count)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int arg{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int arg{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.feed(i).name().empty()) {
|
||||
@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
result_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int result{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int result{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
*methods += RewriteWithName(absl::StrCat(i), code, rewrites);
|
||||
if (!config.fetch(i).name().empty()) {
|
||||
@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config,
|
||||
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
|
||||
arg_data({{I}}))){{INDICES}};
|
||||
}
|
||||
int var_{{NAME}}_size() const {
|
||||
return {{COUNT}} * sizeof({{TYPE}});
|
||||
}
|
||||
int var_{{NAME}}_count() const {
|
||||
return {{COUNT}};
|
||||
}
|
||||
)";
|
||||
const tf2xla::Variable& var = config.variable(i - config.feed_size());
|
||||
rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
|
||||
|
@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg0_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg0_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg_myfeed_data(const void* data) {
|
||||
set_arg_data(0, data);
|
||||
@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1][2]>(
|
||||
arg_data(0)))[dim0][dim1];
|
||||
}
|
||||
int arg_myfeed_size() const {
|
||||
return 2 * sizeof(float);
|
||||
}
|
||||
int arg_myfeed_count() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
void set_arg1_data(const void* data) {
|
||||
set_arg_data(1, data);
|
||||
@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int64(*)[3][4]>(
|
||||
arg_data(1)))[dim0][dim1];
|
||||
}
|
||||
int arg1_size() const {
|
||||
return 12 * sizeof(tensorflow::int64);
|
||||
}
|
||||
int arg1_count() const {
|
||||
return 12;
|
||||
}
|
||||
|
||||
// Result methods for managing output buffers. Buffers are in row-major order.
|
||||
// Must only be called after a successful Run call. There is a set of methods
|
||||
@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result0_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result0_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
tensorflow::uint32* result_myfetch_data() {
|
||||
return static_cast<tensorflow::uint32*>(result_data(0));
|
||||
@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
|
||||
result_data(0)))[dim0][dim1];
|
||||
}
|
||||
int result_myfetch_size() const {
|
||||
return 30 * sizeof(tensorflow::uint32);
|
||||
}
|
||||
int result_myfetch_count() const {
|
||||
return 30;
|
||||
}
|
||||
|
||||
// Methods for managing variable buffers. Buffers are in row-major order.
|
||||
//
|
||||
@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(2)))[0];
|
||||
}
|
||||
int var_myvar_readonly_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_readonly_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar_data(float* data) {
|
||||
set_arg_data(3, data);
|
||||
@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const float(*)[1]>(
|
||||
arg_data(3)))[0];
|
||||
}
|
||||
int var_myvar_size() const {
|
||||
return 1 * sizeof(float);
|
||||
}
|
||||
int var_myvar_count() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void set_var_myvar2_data(tensorflow::int32* data) {
|
||||
set_arg_data(4, data);
|
||||
@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
|
||||
return (*static_cast<const tensorflow::int32(*)[5]>(
|
||||
arg_data(4)))[dim0];
|
||||
}
|
||||
int var_myvar2_size() const {
|
||||
return 5 * sizeof(tensorflow::int32);
|
||||
}
|
||||
int var_myvar2_count() const {
|
||||
return 5;
|
||||
}
|
||||
|
||||
private:
|
||||
// Number of buffers for the compiled computation.
|
||||
|
@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
data::MakeIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \
|
||||
data::DeleteIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \
|
||||
data::DeleteIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
|
||||
data::IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
|
||||
|
@ -31,7 +31,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -695,9 +695,9 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
|
@ -1020,7 +1020,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
if (!inst->getMutableAttrDict().getAttrs().empty()) {
|
||||
os << " {";
|
||||
bool first = true;
|
||||
for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) {
|
||||
for (auto& named_attr : inst->getAttrDictionary()) {
|
||||
os << (!first ? ", " : "");
|
||||
first = false;
|
||||
named_attr.first.print(os);
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
|
||||
|
||||
@ -247,7 +247,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
|
||||
class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
|
||||
And<[
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>,
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||
@ -255,13 +262,13 @@ class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
|
||||
class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
|
||||
class TFL_OperandHasRankAtMost<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
|
||||
class TFL_OperandHasRankGreaterThanOrEqualTo<int n, int m> :
|
||||
class TFL_OperandHasRankAtLeast<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at least " # m # "-D",
|
||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
@ -300,6 +307,18 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
||||
"quant::QuantizedType::castToStorageType("
|
||||
"getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
|
||||
|
||||
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
||||
class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
|
||||
TCopVTEtAreSameAt<[i, j]>,
|
||||
TFL_TFOperandTypesWithSameBits<i, j, 8>,
|
||||
And<[
|
||||
SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
|
||||
quant_QuantizedType.predicate>,
|
||||
CPred<"quant::QuantizedType::castToStorageType("
|
||||
"getElementTypeOrSelf($_op.getOperand(" # i # "))) == "
|
||||
"quant::QuantizedType::castToStorageType("
|
||||
"getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFL op common constraints.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -963,7 +982,11 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
|
||||
|
||||
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
|
||||
def TFL_LessEqualOp : TFL_Op<"less_equal", [
|
||||
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
ResultsBroadcastableShape,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Less_equal operator";
|
||||
|
||||
let description = [{
|
||||
@ -971,8 +994,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
@ -985,9 +1008,12 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
||||
def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization",
|
||||
[NoSideEffect]> {
|
||||
let summary = "Local Response Normalization.";
|
||||
def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [
|
||||
TFL_OperandHasRank<0, 4>,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType,
|
||||
NoSideEffect]> {
|
||||
let summary = "Local Response Normalization.";
|
||||
|
||||
let description = [{
|
||||
The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
|
||||
@ -1004,7 +1030,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$input,
|
||||
TFL_FpTensor:$input,
|
||||
I32Attr:$radius,
|
||||
F32Attr:$bias,
|
||||
F32Attr:$alpha,
|
||||
@ -1012,7 +1038,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$output
|
||||
TFL_FpTensor:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -1048,7 +1074,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
|
||||
NoSideEffect,
|
||||
TFL_OperandHasAtleastRank<0, 1>,
|
||||
PredOpTrait<"operand and result must have the same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
let summary = [{
|
||||
Returns a tensor with the provided diagonal and everything else padded with zeros.
|
||||
}];
|
||||
@ -1061,17 +1087,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal
|
||||
TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output
|
||||
TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
||||
def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> {
|
||||
def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [
|
||||
TFL_OperandHasAtleastRank<0, 2>,
|
||||
PredOpTrait<"input and result must have the same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched matrix tensor with new batched diagonal values.
|
||||
}];
|
||||
@ -1083,12 +1113,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input,
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal
|
||||
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
|
||||
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output
|
||||
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
|
||||
);
|
||||
|
||||
let hasOptions = 0;
|
||||
@ -1206,7 +1236,12 @@ larger than 0.
|
||||
}
|
||||
|
||||
def TFL_NotEqualOp : TFL_Op<"not_equal", [
|
||||
ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> {
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
ResultsBroadcastableShape,
|
||||
Commutative,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Not_equal operator";
|
||||
|
||||
let description = [{
|
||||
@ -1214,8 +1249,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs);
|
||||
ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs,
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
@ -1284,7 +1319,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
|
||||
PredOpTrait<"value and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 1>>,
|
||||
TFL_OperandHasRank<0, 1>,
|
||||
TFL_OperandHasRankGreaterThanOrEqualTo<1, 2>
|
||||
TFL_OperandHasRankAtLeast<1, 2>
|
||||
]> {
|
||||
let summary = "Embedding lookup operator";
|
||||
|
||||
@ -1502,7 +1537,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
|
||||
}
|
||||
|
||||
def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
ResultsBroadcastableShape,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Greater operator";
|
||||
|
||||
let description = [{
|
||||
@ -1510,10 +1549,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
let builders = [TFL_ComparisonBinaryBuilder];
|
||||
|
||||
@ -1523,8 +1562,9 @@ def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
}
|
||||
|
||||
def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
TFL_GpuTargetOp]> {
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Hardswish activation function.";
|
||||
let description = [{
|
||||
Computes hard-swish activation function
|
||||
@ -1563,29 +1603,34 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
|
||||
let customOption = "L2NormOptions";
|
||||
}
|
||||
|
||||
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
|
||||
SameOperandsAndResultShape,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Leaky Relu operator";
|
||||
|
||||
// TODO(jpienaar): Add type restriction. This op is only defined for
|
||||
// restricted (floating point) types.
|
||||
let description = [{
|
||||
Element-wise Leaky ReLU operator
|
||||
x -> x >= 0 ? x : (alpha * x)
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
|
||||
// Slope of the activation function at x < 0.
|
||||
F32Attr:$alpha
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 0b1;
|
||||
}
|
||||
|
||||
def TFL_LessOp : TFL_Op<"less", [
|
||||
ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> {
|
||||
ResultsBroadcastableShape,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Less operator";
|
||||
|
||||
let description = [{
|
||||
@ -1593,8 +1638,8 @@ def TFL_LessOp : TFL_Op<"less", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
||||
@ -1655,6 +1700,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> {
|
||||
|
||||
def TFL_LogisticOp: TFL_Op<"logistic", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"x and y must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
SameOperandsAndResultShape,
|
||||
// zero_point = 0
|
||||
// scale = 1. / (max_value + 1)
|
||||
@ -1667,9 +1714,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
|
||||
Computes element-wise Sigmoid of input
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y);
|
||||
}
|
||||
|
||||
def TFL_LogOp: TFL_Op<"log", [
|
||||
@ -1690,10 +1737,10 @@ def TFL_LogOp: TFL_Op<"log", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO(b/130643170): Adds some constraint for the input/output element types.
|
||||
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType,
|
||||
// zero_point = max_value
|
||||
// scale = -log_softmax_output_min / (max_value + 1)
|
||||
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
|
||||
@ -1706,9 +1753,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
|
||||
input - log(reduce_sum(exp(input), dim))
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$input);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
@ -1727,6 +1774,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and "
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>]>>;
|
||||
|
||||
def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
|
||||
TFL_OperandHasRank<0, 4>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
MaxPoolOperandAndResultConstraints,
|
||||
SameOperandsAndResultsScale,
|
||||
@ -1741,7 +1791,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input,
|
||||
TFL_PaddingAttr:$padding,
|
||||
I32Attr:$stride_w,
|
||||
I32Attr:$stride_h,
|
||||
@ -1750,7 +1800,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
|
||||
TFL_AFAttr:$fused_activation_function
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
@ -1782,7 +1832,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
||||
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
|
||||
def TFL_MeanOp : TFL_Op<"mean", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Mean operator";
|
||||
|
||||
let description = [{
|
||||
@ -1794,13 +1848,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input,
|
||||
TFL_TensorOf<[I32, I64]>:$axis,
|
||||
BoolAttr:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output);
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
let customOption = "ReducerOptions";
|
||||
@ -1821,14 +1875,14 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[I32, I64]>:$indices,
|
||||
TFL_I32Tensor:$depth,
|
||||
TFL_TensorOf<[F32, I32, I64, I1]>:$on_value,
|
||||
TFL_TensorOf<[F32, I32, I64, I1]>:$off_value,
|
||||
TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value,
|
||||
TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value,
|
||||
|
||||
I32Attr:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, I1]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2032,7 +2086,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
def TFL_PackOp : TFL_Op<"pack", [
|
||||
PredOpTrait<"values and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Packs a list of tensors along a dimension into one tensor";
|
||||
|
||||
let description = [{
|
||||
@ -2063,14 +2121,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values,
|
||||
TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values,
|
||||
|
||||
I32Attr:$values_count,
|
||||
Confined<I32Attr, [IntPositive]>:$values_count,
|
||||
I32Attr:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
|
||||
TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
@ -2081,8 +2139,11 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
}
|
||||
|
||||
def TFL_PadOp : TFL_Op<"pad", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRankAtMost<0, 4>,
|
||||
TFL_OperandHasRank<1, 2>,
|
||||
TFL_OperandRankEquals1DimOfOperand<0, 1>,
|
||||
TFL_GpuTargetOp]> {
|
||||
@ -2113,22 +2174,25 @@ def TFL_PadOp : TFL_Op<"pad", [
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
|
||||
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32OrI64Tensor:$padding);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_PadV2Op : TFL_Op<"padv2", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRankAtMost<0, 4>,
|
||||
TFL_OperandHasRank<1, 2>,
|
||||
TFL_OperandHasRank<2, 0>,
|
||||
TFL_OperandRankEquals1DimOfOperand<0, 1>,
|
||||
PredOpTrait<"input and constant value operands must have same element type",
|
||||
TCopVTEtAreSameAt<[0, 2]>>]> {
|
||||
TFL_TCopVTEtAreSameAt<0, 2>>]> {
|
||||
let summary = "Padding operator v2";
|
||||
|
||||
let description = [{
|
||||
@ -2159,11 +2223,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
|
||||
ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32OrI64Tensor:$padding,
|
||||
TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values);
|
||||
TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
@ -2191,9 +2255,21 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
}
|
||||
|
||||
def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
|
||||
TFL_GpuTargetOp,
|
||||
SameOperandsAndResultsScale]> {
|
||||
def TFL_PReluOp : TFL_Op<"prelu", [
|
||||
NoSideEffect,
|
||||
ResultsBroadcastableShape,
|
||||
TFL_GpuTargetOp,
|
||||
TFL_OperandHasRankAtMost<0, 4>,
|
||||
TFL_OperandHasRankAtMost<1, 4>,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
PredOpTrait<"input and output must have the same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"'alpha' should have one less rank than 'input'.",
|
||||
Or<[TFL_OperandIsUnrankedPred<0>,
|
||||
TFL_OperandIsUnrankedPred<1>,
|
||||
CPred<"$_op.getOperand(0).getType().cast<ShapedType>().getRank() == "
|
||||
"$_op.getOperand(1).getType().cast<ShapedType>().getRank() "
|
||||
"+ 1">]>>]> {
|
||||
let summary = "Parameterized Relu operator";
|
||||
|
||||
let description = [{
|
||||
@ -2206,11 +2282,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect,
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, QUI8]>:$input,
|
||||
TFL_TensorOf<[F32, QUI8]>:$alpha
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
@ -2887,7 +2963,7 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [
|
||||
SameOperandsAndResultsScale,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_OperandHasRankLessThanOrEqualTo<0, 4>
|
||||
TFL_OperandHasRankAtMost<0, 4>
|
||||
]> {
|
||||
let summary = "DepthToSpace operator";
|
||||
|
||||
@ -3224,7 +3300,7 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
|
||||
ElementsAttr:$value
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[QUI8, QI8, QI16, QUI16, TFL_Quint8]>:$output);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value",
|
||||
@ -3849,7 +3925,7 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[QI8, QUI8, QI16, QUI16]>:$input,
|
||||
TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input,
|
||||
TFL_TensorOf<[F32]>:$ref,
|
||||
|
||||
// Attributes
|
||||
|
@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
|
||||
saved_model_exported_names.begin(), saved_model_exported_names.end());
|
||||
absl::Span<std::string> exported_names(exported_names_in_vector);
|
||||
|
||||
if (exported_names.size() != 1) {
|
||||
return errors::Unimplemented("Only support a single exported name.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ImportSavedModel(model_flags.saved_model_dir(),
|
||||
model_flags.saved_model_version(), tags,
|
||||
|
@ -573,7 +573,7 @@ func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> {
|
||||
// test invalid Logistic input
|
||||
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
^bb0(%arg0: tensor<?xi32>):
|
||||
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
|
||||
// expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor<?xi32>'}}
|
||||
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0#0 : tensor<?xi32>
|
||||
}
|
||||
@ -1252,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %
|
||||
|
||||
// -----
|
||||
|
||||
func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi8> {
|
||||
// expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}}
|
||||
%0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi8>
|
||||
return %0 : tensor<*xi8>
|
||||
func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi16> {
|
||||
// expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}}
|
||||
%0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi16>
|
||||
return %0 : tensor<*xi16>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1489,7 +1489,8 @@ func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor<?xi
|
||||
|
||||
// -----
|
||||
|
||||
func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
|
||||
func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
|
||||
// expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform<u8:f32, 2.000000e-02>>'}}
|
||||
%0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
|
||||
return %0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
|
||||
}
|
||||
@ -1523,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> {
|
||||
// expected-error @+1 {{'input' and 'output' should have the same rank}}
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32>
|
||||
return %0 : tensor<10x10x10xf32>
|
||||
func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> {
|
||||
// expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}}
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
|
||||
return %0 : tensor<10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> {
|
||||
// expected-error @+1 {{'input' and 'output' should have the same shape}}
|
||||
// expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}}
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32>
|
||||
return %0 : tensor<1x2x3x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
|
||||
func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> {
|
||||
// expected-error @+1 {{'alpha' should have one less rank than 'input'.}}
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32>
|
||||
return %0 : tensor<7x3x2x14xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> {
|
||||
// expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}}
|
||||
// expected-error @+1 {{'tfl.prelu' op operands don't have broadcast-compatible shapes}}
|
||||
%0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32>
|
||||
return %0 : tensor<15x14x2x14xf32>
|
||||
}
|
||||
|
@ -160,6 +160,11 @@ int main(int argc, char **argv) {
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
if (exported_names.size() != 1) {
|
||||
llvm::errs() << "There should be only one exported name";
|
||||
return kTrFailure;
|
||||
}
|
||||
|
||||
module = tensorflow::ImportSavedModel(input_file_name, saved_model_version,
|
||||
tags, exported_names, &context);
|
||||
} else {
|
||||
|
@ -174,7 +174,7 @@ StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
return module;
|
||||
} else if (saved_model_version == 1) {
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, context);
|
||||
input_filename, tags, exported_names, context);
|
||||
|
||||
if (!module)
|
||||
return tensorflow::errors::InvalidArgument("fail to open input file");
|
||||
|
@ -12,6 +12,22 @@ cc_library(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
# (yongtang) The graph_optimization_pass_registration needs to be part
|
||||
# of a shared object that will be loaded whenever `import tensorflow`
|
||||
# is run. The natural place is libtensorflow_framework.so.
|
||||
# While adding graph_optimization_pass_registration to
|
||||
# libtensorflow_framework.so is possible with some modification in
|
||||
# dependency, many tests will fail due to multiple copies of LLVM.
|
||||
# See https://github.com/tensorflow/tensorflow/pull/39231 for details.
|
||||
# Alternatively, we place graph_optimization_pass_registration here
|
||||
# because:
|
||||
# - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway
|
||||
# - tensorflow/python/_pywrap_mlir.so always loaded as part of python
|
||||
# binding
|
||||
# TODO: It might be still preferrable to place graph_optimization_pass
|
||||
# as part of the libtensorflow_framework.so, as it is the central
|
||||
# place for core related components.
|
||||
"//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:import_utils",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
|
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
41
tensorflow/compiler/mlir/python/mlir_wrapper/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "mlir_wrapper",
|
||||
srcs = [
|
||||
"attrs.cc",
|
||||
"basic_classes.cc",
|
||||
"builders.cc",
|
||||
"mlir_wrapper.cc",
|
||||
"mlir_wrapper.h",
|
||||
"ops.cc",
|
||||
"types.cc",
|
||||
],
|
||||
module_name = "mlir_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "filecheck_wrapper",
|
||||
srcs = ["filecheck_wrapper.cc"],
|
||||
module_name = "filecheck_wrapper",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@llvm-project//llvm:support",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
25
tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_attrs(py::module& m) {
|
||||
py::class_<mlir::Attribute>(m, "Attribute");
|
||||
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr")
|
||||
.def("get",
|
||||
py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_basic_classes(py::module& m) {
|
||||
py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());
|
||||
|
||||
py::class_<mlir::Location>(m, "Location");
|
||||
|
||||
py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
|
||||
.def("get", &mlir::UnknownLoc::get);
|
||||
|
||||
py::class_<mlir::Region>(m, "Region")
|
||||
.def("back", &mlir::Region::back, py::return_value_policy::reference)
|
||||
.def("front", &mlir::Region::front, py::return_value_policy::reference)
|
||||
.def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
|
||||
.def("push_back", &mlir::Region::push_back)
|
||||
.def("size", [](mlir::Region& r) { return r.getBlocks().size(); })
|
||||
.def("front", &mlir::Region::front, py::return_value_policy::reference);
|
||||
py::class_<mlir::Block::iterator>(m, "Block_Iterator");
|
||||
py::class_<mlir::Block>(m, "Block")
|
||||
.def("new", ([]() { return new mlir::Block; }),
|
||||
py::return_value_policy::reference)
|
||||
.def("end", &mlir::Block::end)
|
||||
.def("addArgument", &mlir::Block::addArgument);
|
||||
|
||||
py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
|
||||
py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
|
||||
py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
|
||||
}
|
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
51
tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
void init_builders(py::module& m) {
|
||||
py::class_<mlir::Builder>(m, "Builder")
|
||||
.def(py::init<mlir::MLIRContext*>())
|
||||
.def("getFunctionType",
|
||||
[](mlir::Builder& b, std::vector<mlir::Type> inputs,
|
||||
std::vector<mlir::Type> outputs) {
|
||||
return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
|
||||
llvm::ArrayRef<mlir::Type>(outputs));
|
||||
});
|
||||
py::class_<mlir::OpBuilder>(m, "OpBuilder")
|
||||
.def(py::init<mlir::MLIRContext*>())
|
||||
.def(py::init<mlir::Region&>())
|
||||
.def(py::init<mlir::Operation*>())
|
||||
.def(py::init<mlir::Block*, mlir::Block::iterator>())
|
||||
.def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc)
|
||||
.def("setInsertionPoint",
|
||||
py::overload_cast<mlir::Block*, mlir::Block::iterator>(
|
||||
&mlir::OpBuilder::setInsertionPoint))
|
||||
.def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint)
|
||||
.def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint)
|
||||
.def(
|
||||
"createOperation",
|
||||
[](mlir::OpBuilder& opb, mlir::OperationState& state) {
|
||||
return opb.createOperation(state);
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def("getContext", &mlir::OpBuilder::getContext,
|
||||
py::return_value_policy::reference);
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
|
||||
.def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/FileCheck.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(filecheck_wrapper, m) {
|
||||
m.def("check", [](std::string input, std::string check) {
|
||||
llvm::FileCheckRequest fcr;
|
||||
llvm::FileCheck fc(fcr);
|
||||
llvm::SourceMgr SM = llvm::SourceMgr();
|
||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
|
||||
llvm::SMLoc());
|
||||
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
|
||||
llvm::SMLoc());
|
||||
llvm::Regex regex = fc.buildCheckPrefixRegex();
|
||||
fc.readCheckFile(SM, llvm::StringRef(check), regex);
|
||||
return fc.checkInput(SM, llvm::StringRef(input));
|
||||
});
|
||||
}
|
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
38
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
PYBIND11_MODULE(mlir_wrapper, m) {
|
||||
m.def("registerDialects", []() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
});
|
||||
|
||||
init_basic_classes(m);
|
||||
init_types(m);
|
||||
init_builders(m);
|
||||
init_ops(m);
|
||||
init_attrs(m);
|
||||
}
|
30
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h
Normal file
30
tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void init_basic_classes(py::module& m);
|
||||
void init_types(py::module& m);
|
||||
void init_builders(py::module& m);
|
||||
void init_ops(py::module& m);
|
||||
void init_attrs(py::module& m);
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
|
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
194
tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc
Normal file
@ -0,0 +1,194 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
void init_ops(py::module& m) {
|
||||
py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
|
||||
m, "Operation")
|
||||
.def("getRegion", &mlir::Operation::getRegion,
|
||||
py::return_value_policy::reference)
|
||||
.def("getResult", &mlir::Operation::getResult)
|
||||
.def("dump", &mlir::Operation::dump)
|
||||
.def("getNumResults", &mlir::Operation::getNumResults);
|
||||
|
||||
py::class_<mlir::OperationState>(m, "OperationState")
|
||||
.def(py::init([](mlir::Location loc, std::string name) {
|
||||
return mlir::OperationState(loc, llvm::StringRef(name));
|
||||
}))
|
||||
.def("addTypes",
|
||||
[](mlir::OperationState& state, std::vector<mlir::Type> tys) {
|
||||
state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
|
||||
})
|
||||
.def("addOperands",
|
||||
[](mlir::OperationState& os, std::vector<mlir::Value> ops) {
|
||||
os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
|
||||
})
|
||||
.def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
|
||||
py::return_value_policy::reference);
|
||||
|
||||
py::class_<mlir::ModuleOp>(m, "ModuleOp")
|
||||
.def("create",
|
||||
[](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
|
||||
.def("push_back",
|
||||
[](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
|
||||
.def("dump", &mlir::ModuleOp::dump)
|
||||
.def("getAsStr", [](mlir::ModuleOp& m) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
m.print(os);
|
||||
return os.str();
|
||||
});
|
||||
|
||||
py::class_<mlir::FuncOp>(m, "FuncOp")
|
||||
.def("create",
|
||||
[](mlir::Location location, std::string name,
|
||||
mlir::FunctionType type) {
|
||||
auto func = mlir::FuncOp::create(location, name, type);
|
||||
func.addEntryBlock();
|
||||
return func;
|
||||
})
|
||||
.def(
|
||||
"getBody",
|
||||
[](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
|
||||
py::return_value_policy::reference)
|
||||
.def("getArguments",
|
||||
[](mlir::FuncOp& f) { return f.getArguments().vec(); })
|
||||
.def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
|
||||
.def("getType", &mlir::FuncOp::getType);
|
||||
|
||||
py::class_<mlir::ReturnOp>(m, "ReturnOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
std::vector<mlir::Value> values) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::ReturnOp>(loc,
|
||||
mlir::ArrayRef<mlir::Value>(values))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::AddOp
|
||||
py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
|
||||
mlir::Value reduction_indices,
|
||||
bool keep_dims = false) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
|
||||
reduction_indices, keep_dims)
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::ConstOp
|
||||
py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
mlir::Attribute value) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::EqualOp
|
||||
py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::GreaterEqualOp
|
||||
py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::GreaterOp
|
||||
py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LegacyCallOp
|
||||
py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
std::vector<mlir::Type> output, std::vector<mlir::Value> args,
|
||||
std::string f) -> mlir::Operation* {
|
||||
return opb
|
||||
.create<mlir::TF::LegacyCallOp>(
|
||||
loc, mlir::ArrayRef<mlir::Type>(output),
|
||||
mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LessEqualOp
|
||||
py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::LessOp
|
||||
py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::NegOp
|
||||
py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc,
|
||||
mlir::Value x) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
|
||||
});
|
||||
|
||||
py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
|
||||
.def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) {
|
||||
return opb
|
||||
.create<mlir::TF::NotEqualOp>(
|
||||
loc, x, y, mlir::BoolAttr::get(true, opb.getContext()))
|
||||
.getOperation();
|
||||
});
|
||||
|
||||
// mlir::TF::SubOp
|
||||
py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
|
||||
.def("create",
|
||||
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
|
||||
mlir::Value y) -> mlir::Operation* {
|
||||
return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
|
||||
});
|
||||
}
|
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
48
tensorflow/compiler/mlir/python/mlir_wrapper/types.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
void init_types(py::module& m) {
|
||||
// Type
|
||||
py::class_<mlir::Type> Type(m, "Type");
|
||||
Type.def("getKind", &mlir::Type::getKind);
|
||||
|
||||
// Type Enums
|
||||
py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
|
||||
.value("BF16", mlir::StandardTypes::BF16);
|
||||
|
||||
// Type Sub-classes
|
||||
py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
|
||||
.def("getResults",
|
||||
[](mlir::FunctionType& ft) { return ft.getResults().vec(); });
|
||||
|
||||
py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
|
||||
.def("get", &mlir::FloatType::get);
|
||||
|
||||
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
|
||||
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
|
||||
&mlir::IntegerType::get));
|
||||
|
||||
py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
|
||||
.def("get", &mlir::UnrankedTensorType::get);
|
||||
|
||||
py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
|
||||
.def("get", [](std::vector<int64_t> shape, mlir::Type ty) {
|
||||
return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(shape), ty);
|
||||
});
|
||||
}
|
@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
|
||||
]
|
||||
tool_names = [
|
||||
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
|
||||
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
|
||||
'xla-opt'
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
|
||||
'tensorflow/compiler/mlir',
|
||||
'tensorflow/compiler/mlir/lite',
|
||||
'tensorflow/compiler/mlir/tensorflow',
|
||||
'tensorflow/compiler/mlir/tfjs',
|
||||
'tensorflow/compiler/mlir/xla',
|
||||
'tensorflow/compiler/aot',
|
||||
'tensorflow/compiler/xla/service/mlir_gpu',
|
||||
|
@ -36,7 +36,7 @@ filegroup(
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -556,7 +556,7 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -823,6 +823,7 @@ cc_library(
|
||||
":mangling_util",
|
||||
":tensorflow_attributes",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -1074,7 +1075,7 @@ genrule(
|
||||
srcs = [
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"ir/tf_generated_ops.td",
|
||||
"ir/tf_op_base.td",
|
||||
|
@ -192,6 +192,44 @@ retained with length 1.
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
|
||||
let summary = "An Op to exchange data across TPU replicas.";
|
||||
|
||||
let description = [{
|
||||
On each replica, the input is split into `split_count` blocks along
|
||||
`split_dimension` and send to the other replicas given group_assignment. After
|
||||
receiving `split_count` - 1 blocks from other replicas, we concatenate the
|
||||
blocks along `concat_dimension` as the output.
|
||||
|
||||
For example, suppose there are 2 TPU replicas:
|
||||
replica 0 receives input: `[[A, B]]`
|
||||
replica 1 receives input: `[[C, D]]`
|
||||
|
||||
group_assignment=`[[0, 1]]`
|
||||
concat_dimension=0
|
||||
split_dimension=1
|
||||
split_count=2
|
||||
|
||||
replica 0's output: `[[A], [C]]`
|
||||
replica 1's output: `[[B], [D]]`
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
|
||||
I32Tensor:$group_assignment,
|
||||
|
||||
I64Attr:$concat_dimension,
|
||||
I64Attr:$split_dimension,
|
||||
I64Attr:$split_count
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "Returns the argument of a complex number.";
|
||||
|
||||
@ -1217,7 +1255,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
let description = [{
|
||||
@ -1408,6 +1446,30 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Shuffle dimensions of x according to a permutation and conjugate the result.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
|
||||
`y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
|
||||
`y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
TF_I32OrI64Tensor:$perm
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
|
||||
let summary = [{
|
||||
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
|
||||
@ -1682,7 +1744,28 @@ Given an input tensor, this function computes hyperbolic cosine of every
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
|
||||
def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
|
||||
let summary = "Compute the pairwise cross product.";
|
||||
|
||||
let description = [{
|
||||
`a` and `b` must be the same shape; they can either be simple 3-element vectors,
|
||||
or any shape where the innermost dimension is 3. In the latter case, each pair
|
||||
of corresponding 3-element vectors is cross-multiplied independently.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntOrFpTensor:$a,
|
||||
TF_IntOrFpTensor:$b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntOrFpTensor:$product
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
|
||||
let summary = "An Op to sum inputs across replicated TPU instances.";
|
||||
|
||||
let description = [{
|
||||
@ -1706,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> {
|
||||
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
|
||||
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
|
||||
|
||||
let description = [{
|
||||
@ -3256,8 +3339,8 @@ Gather slices from `params` axis `axis` according to `indices`.
|
||||
|
||||
let description = [{
|
||||
`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
|
||||
Produces an output tensor with shape `params.shape[:axis] + indices.shape +
|
||||
params.shape[axis + 1:]` where:
|
||||
Produces an output tensor with shape `params.shape[:axis] +
|
||||
indices.shape[batch_dims:] + params.shape[axis + 1:]` where:
|
||||
|
||||
```python
|
||||
# Scalar indices (output is rank(params) - 1).
|
||||
@ -3542,6 +3625,31 @@ tf.imag(input) ==> [4.75, 5.75]
|
||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Create a copy of `x` with the updated specified rows 'i' with values 'v'.
|
||||
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Creates a copy of tensor 'x' and updates the columns specified in tensor 'i'
|
||||
with the values 'v'. Originally this function was mutative however for
|
||||
compilation we make this operation create / operate on a copy.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
I32Tensor:$i,
|
||||
TF_Tensor:$v
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the reciprocal of x element-wise.";
|
||||
|
||||
@ -4242,7 +4350,7 @@ cublas.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> {
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> {
|
||||
let summary = [{
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
}];
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#define TF_OP_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
|
||||
"$_op.getOperand(" # opId # ").getType(), "
|
||||
"$_op.getResult(" # resId # ").getType())">]>;
|
||||
|
||||
|
||||
class TF_AllTypesMatchPred<list<string> values> :
|
||||
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
|
||||
|
||||
class TF_AllTypesMatch<list<string> names> :
|
||||
PredOpTrait<
|
||||
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
|
||||
TF_AllTypesMatchPred<
|
||||
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -110,47 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
|
||||
return !type || type.getRank() <= rank;
|
||||
}
|
||||
|
||||
// Returns true if the given pair of TensorFlow types can be cast to one
|
||||
// another. In other words, a single run-time value is legal for both the types.
|
||||
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
|
||||
static bool AreCastCompatible(Type a, Type b) {
|
||||
if (TensorCastOp::areCastCompatible(a, b)) return true;
|
||||
|
||||
// Resource types may optionally contain subtypes information that does not
|
||||
// match. Check subtypes compatibility when possible, otherwise treat them as
|
||||
// compatible.
|
||||
auto a_or_element_type = getElementTypeOrSelf(a);
|
||||
auto b_or_element_type = getElementTypeOrSelf(b);
|
||||
|
||||
auto a_kind = a_or_element_type.getKind();
|
||||
auto b_kind = b_or_element_type.getKind();
|
||||
|
||||
if (a_kind == TensorFlowTypes::RESOURCE &&
|
||||
b_kind == TensorFlowTypes::RESOURCE) {
|
||||
auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
|
||||
auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
|
||||
bool a_has_subtype = !a_resource_type.getSubtypes().empty();
|
||||
bool b_has_subtype = !b_resource_type.getSubtypes().empty();
|
||||
|
||||
if (!a_has_subtype || !b_has_subtype) return true;
|
||||
|
||||
assert(a_resource_type.getSubtypes().size() <= 1 &&
|
||||
"Resource type must have at most one subtype");
|
||||
assert(b_resource_type.getSubtypes().size() <= 1 &&
|
||||
"Resource type must have at most one subtype");
|
||||
|
||||
return TensorCastOp::areCastCompatible(
|
||||
a_resource_type.getSubtypes().front(),
|
||||
b_resource_type.getSubtypes().front());
|
||||
}
|
||||
|
||||
// Variant types may optionally contain subtypes information that need not
|
||||
// match. It is also not possible to compare subtypes for compatibility as
|
||||
// their interpretation depends on the ops operating on them. So, accept all
|
||||
// pairs of variant types.
|
||||
return a_kind == TensorFlowTypes::VARIANT &&
|
||||
b_kind == TensorFlowTypes::VARIANT;
|
||||
}
|
||||
|
||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||
return dim_or_rank == -1;
|
||||
@ -984,20 +943,17 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
|
||||
|
||||
LogicalResult ConstOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
for (NamedAttribute named_attr : attributes) {
|
||||
if (named_attr.first.strref() != "value") continue;
|
||||
auto value = named_attr.second;
|
||||
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
|
||||
inferredReturnTypes.assign({elem_attr.getType()});
|
||||
return success();
|
||||
}
|
||||
return emitOptionalError(location,
|
||||
"attribute 'value' failed to satisfy constraint: "
|
||||
"constant vector/tensor");
|
||||
auto value = attributes.get("value");
|
||||
if (!value) return emitOptionalError(location, "missing attribute 'value'");
|
||||
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
|
||||
inferredReturnTypes.assign({elem_attr.getType()});
|
||||
return success();
|
||||
}
|
||||
return emitOptionalError(location, "missing attribute 'value'");
|
||||
return emitOptionalError(location,
|
||||
"attribute 'value' failed to satisfy constraint: "
|
||||
"constant vector/tensor");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1416,7 +1372,7 @@ static LogicalResult Verify(DynamicStitchOp op) {
|
||||
auto expected_out_ty =
|
||||
RankedTensorType::get(expected_shape, out_ty.getElementType());
|
||||
|
||||
if (!AreCastCompatible(out_ty, expected_out_ty)) {
|
||||
if (!AreCastCompatible({out_ty, expected_out_ty})) {
|
||||
return op.emitOpError() << "has invalid output type; should be "
|
||||
"compatible with inferred type "
|
||||
<< expected_out_ty;
|
||||
@ -1817,14 +1773,14 @@ static LogicalResult Verify(IfOp op) {
|
||||
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
||||
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(operandType, thenInputType))
|
||||
if (!AreCastCompatible({operandType, thenInputType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
thenInputType, operandType, i));
|
||||
|
||||
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(operandType, elseInputType))
|
||||
if (!AreCastCompatible({operandType, elseInputType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
@ -1832,7 +1788,7 @@ static LogicalResult Verify(IfOp op) {
|
||||
|
||||
// If branches have incompatible input types that means that no tensor can
|
||||
// serve as input to both the functions. Hence, the op is invalid.
|
||||
if (!AreCastCompatible(thenInputType, elseInputType))
|
||||
if (!AreCastCompatible({thenInputType, elseInputType}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"branches inputs have incompatible types {0} and {1} at index {2}",
|
||||
thenInputType, elseInputType, i));
|
||||
@ -1848,14 +1804,14 @@ static LogicalResult Verify(IfOp op) {
|
||||
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
||||
auto resultType = op.getResult(i).getType().cast<TensorType>();
|
||||
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(thenResultType, resultType))
|
||||
if (!AreCastCompatible({thenResultType, resultType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
thenResultType, resultType, i));
|
||||
|
||||
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible(elseResultType, resultType))
|
||||
if (!AreCastCompatible({elseResultType, resultType}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
@ -3792,7 +3748,7 @@ static LogicalResult Verify(WhileOp op) {
|
||||
auto aType = a.second[idx];
|
||||
auto bType = b.second[idx];
|
||||
|
||||
if (!AreCastCompatible(aType, bType))
|
||||
if (!AreCastCompatible({aType, bType}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"{0} type {1} is incompatible with {2} type {3} at index {4}",
|
||||
a.first, aType, b.first, bType, idx));
|
||||
|
@ -28,6 +28,134 @@ llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
|
||||
if (shaped_type.hasRank()) return shaped_type.getShape();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Merges cast compatible shapes and returns a more refined shape. The two
|
||||
// shapes are cast compatible if they have the same rank and at each dimension,
|
||||
// either both have same size or one of them is dynamic. Returns false if the
|
||||
// given shapes are not cast compatible. The refined shape is same or more
|
||||
// precise than the two input shapes.
|
||||
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
|
||||
llvm::ArrayRef<int64_t> b_shape,
|
||||
llvm::SmallVectorImpl<int64_t>* refined_shape) {
|
||||
if (a_shape.size() != b_shape.size()) return false;
|
||||
int64_t rank = a_shape.size();
|
||||
refined_shape->reserve(rank);
|
||||
for (auto dims : llvm::zip(a_shape, b_shape)) {
|
||||
int64_t dim1 = std::get<0>(dims);
|
||||
int64_t dim2 = std::get<1>(dims);
|
||||
|
||||
if (mlir::ShapedType::isDynamic(dim1)) {
|
||||
refined_shape->push_back(dim2);
|
||||
continue;
|
||||
}
|
||||
if (mlir::ShapedType::isDynamic(dim2)) {
|
||||
refined_shape->push_back(dim1);
|
||||
continue;
|
||||
}
|
||||
if (dim1 == dim2) {
|
||||
refined_shape->push_back(dim1);
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given two types `a` and `b`, returns a refined type which is cast compatible
|
||||
// with both `a` and `b` and is equal to or more precise than both of them. It
|
||||
// returns empty Type if the input types are not cast compatible.
|
||||
//
|
||||
// The two types are considered cast compatible if they have dynamically equal
|
||||
// shapes and element type. For element types that do not have subtypes, they
|
||||
// must be equal. However for TensorFlow types such as Resource and Variant,
|
||||
// that also have subtypes, we recursively check for subtype compatibilty for
|
||||
// Resource types and assume all variant types are cast compatible. If either
|
||||
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
|
||||
//
|
||||
// The returned type is same or more precise than the input types. For example,
|
||||
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
|
||||
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
|
||||
//
|
||||
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
|
||||
// might allow operands to either be same as result type or be a ref type
|
||||
// corresponding to it.
|
||||
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
|
||||
bool may_ignore_ref_type_a) {
|
||||
// Fast path if everything is equal.
|
||||
if (a == b) return b;
|
||||
|
||||
auto a_tt = a.dyn_cast<mlir::TensorType>();
|
||||
auto b_tt = b.dyn_cast<mlir::TensorType>();
|
||||
|
||||
// If only one of a or b is a tensor type, they are incompatible.
|
||||
if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
|
||||
|
||||
// For non-tensor types, we do not need to worry about shape and can return
|
||||
// early.
|
||||
if (!a_tt && !b_tt) {
|
||||
// Remove ref types.
|
||||
if (may_ignore_ref_type_a) {
|
||||
if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
|
||||
a = ref_type.RemoveRef();
|
||||
if (a == b) return a;
|
||||
}
|
||||
}
|
||||
if (a.getKind() != b.getKind()) return nullptr;
|
||||
|
||||
// If either is not a type that contain subtypes then the types are not cast
|
||||
// compatible.
|
||||
auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
|
||||
auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
|
||||
if (!a_wst || !b_wst) return nullptr;
|
||||
|
||||
// For Variant types we are more permissive right now and accept all pairs
|
||||
// of Variant types. If we are more constrainted and check compatibility of
|
||||
// subtypes, we might reject valid graphs.
|
||||
// TODO(prakalps): Variant doesn't have a subtype, we assign it
|
||||
// one, so we should only assign it one when we know the subtype. Then we
|
||||
// can be more constrained and check subtypes for cast compatibility as
|
||||
// well.
|
||||
if (a.isa<mlir::TF::VariantType>()) return a;
|
||||
|
||||
// For Resource types, we recursively check the subtypes for cast
|
||||
// compatibility, if possible. Otherwise treat them as compatible.
|
||||
auto a_wst_st = a_wst.GetSubtypes();
|
||||
auto b_wst_st = b_wst.GetSubtypes();
|
||||
if (a_wst_st.empty() || b_wst_st.empty()) return a;
|
||||
if (a_wst_st.size() != b_wst_st.size()) return nullptr;
|
||||
llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
|
||||
for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
|
||||
mlir::Type refined_st =
|
||||
GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
|
||||
/*may_ignore_ref_type_a=*/false);
|
||||
if (!refined_st) return nullptr;
|
||||
refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
|
||||
}
|
||||
|
||||
return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
|
||||
}
|
||||
|
||||
// For tensor types, check compatibility of both element type and shape.
|
||||
mlir::Type refined_element_ty = GetCastCompatibleType(
|
||||
a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
|
||||
if (!refined_element_ty) return nullptr;
|
||||
|
||||
if (!a_tt.hasRank() && !b_tt.hasRank()) {
|
||||
return mlir::UnrankedTensorType::get(refined_element_ty);
|
||||
}
|
||||
if (!a_tt.hasRank()) {
|
||||
return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
|
||||
}
|
||||
if (!b_tt.hasRank()) {
|
||||
return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 8> refined_shape;
|
||||
if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
|
||||
return nullptr;
|
||||
|
||||
return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
|
||||
|
||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
bool may_ignore_ref_type_lhs) {
|
||||
// Fast path if everything is equal.
|
||||
if (lhs == rhs) return true;
|
||||
return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
|
||||
}
|
||||
|
||||
// In TF all values are tensors.
|
||||
auto lhs_tt = lhs.cast<TensorType>();
|
||||
auto rhs_tt = rhs.cast<TensorType>();
|
||||
|
||||
// Verify matching element types. These should be identical dynamically,
|
||||
// so this allows for types not yet fully refined.
|
||||
auto lhs_et = lhs_tt.getElementType();
|
||||
auto rhs_et = rhs_tt.getElementType();
|
||||
if (lhs_et == rhs_et) return true;
|
||||
|
||||
// Remove ref types.
|
||||
if (may_ignore_ref_type_lhs) {
|
||||
if (auto ref_type = lhs_et.dyn_cast<TF::TensorFlowRefType>()) {
|
||||
lhs_et = ref_type.RemoveRef();
|
||||
if (lhs_et == rhs_et) return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (lhs_et.getKind() != rhs_et.getKind()) return false;
|
||||
|
||||
// If either is not type that contain subtypes then the element types don't
|
||||
// match.
|
||||
auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
if (!lhs_wst || !rhs_wst) return false;
|
||||
|
||||
// Consider the subtype recursively.
|
||||
auto lhs_wst_st = lhs_wst.GetSubtypes();
|
||||
auto rhs_wst_st = rhs_wst.GetSubtypes();
|
||||
if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true;
|
||||
if (lhs_wst_st.size() != rhs_wst_st.size()) return false;
|
||||
for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
|
||||
if (!HasCompatibleElementTypes(std::get<0>(subtypes),
|
||||
std::get<1>(subtypes)))
|
||||
return false;
|
||||
bool AreCastCompatible(ArrayRef<Type> types) {
|
||||
Type common = types.front();
|
||||
for (auto type : types.drop_front()) {
|
||||
Type refined_type =
|
||||
GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
|
||||
if (!refined_type) return false;
|
||||
common = refined_type;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
|
||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
bool may_ignore_ref_type_lhs = false);
|
||||
|
||||
// Returns true if all TensorFlow types can be cast to one
|
||||
// another. In other words, a single run-time value is legal for both the types.
|
||||
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
|
||||
// compatible.
|
||||
bool AreCastCompatible(ArrayRef<Type> types);
|
||||
|
||||
} // end namespace TF
|
||||
} // end namespace mlir
|
||||
|
||||
|
@ -881,20 +881,29 @@ func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOpUnrankedBand
|
||||
func @testValidMatrixBandPartOpUnrankedBand(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOpCompatibleDynamicShapes
|
||||
func @testValidMatrixBandPartOpCompatibleDynamicShapes(%arg0: tensor<?x10x?xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<?x?x8xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<?x10x?xbf16>, tensor<i64>, tensor<i64>) -> tensor<?x?x8xbf16>
|
||||
return %0 : tensor<?x?x8xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have dynamically equal types}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -1,13 +1,17 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with no input or output dependecies.
|
||||
// Tests extraction of a outside compiled ops at head of TPU computation.
|
||||
|
||||
// CHECK-LABEL: func @nodep_single_head_outside_compilation
|
||||
func @nodep_single_head_outside_compilation() -> () {
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
"tf_device.launch"() ( {
|
||||
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: tf_device.launch
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
|
||||
"tf.B"() : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
@ -15,15 +19,62 @@ func @nodep_single_head_outside_compilation() -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @nodep_multiple_head_outside_compilation
|
||||
func @nodep_multiple_head_outside_compilation() -> () {
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
"tf_device.launch"() ( {
|
||||
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
// CHECK-LABEL: func @multiple_head_outside_compilation
|
||||
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
|
||||
"tf.D"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
|
||||
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK-NOT: tf_device.launch
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
|
||||
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK: "tf.E"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"() {} : () -> (tensor<i32>)
|
||||
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
|
@ -3,12 +3,12 @@
|
||||
// Tests that missing `_xla_outside_compilation` attribute value results in an error.
|
||||
|
||||
func @missing_outside_compilation_attribute() -> () {
|
||||
"tf_device.launch"() ( {
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
// expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
|
||||
"tf.B"() {_xla_outside_compilation = ""} : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -18,11 +18,11 @@ func @missing_outside_compilation_attribute() -> () {
|
||||
|
||||
// CHECK-LABEL: func @no_outside_compilation
|
||||
func @no_outside_compilation() -> tensor<?xi32> {
|
||||
%0 = "tf_device.launch"() ( {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -36,16 +36,15 @@ func @nodep_single_outside_compilation() -> () {
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK: device = "tpu0"
|
||||
// CHECK-SAME: launch_attr = "launch_attr"
|
||||
"tf_device.launch"() ( {
|
||||
// CHECK: cluster_attr = "cluster_attr"
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -59,19 +58,18 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: "tf.D"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.E"
|
||||
// CHECK: device = "tpu0"
|
||||
// CHECK-SAME: launch_attr = "launch_attr"
|
||||
"tf_device.launch"() ( {
|
||||
// CHECK: cluster_attr = "cluster_attr"
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.E"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -80,15 +78,16 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
|
||||
// CHECK-LABEL: func @nodep_multiple_outside_compilation
|
||||
func @nodep_multiple_outside_compilation() -> () {
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK-COUNT-3: "tf_device.launch"
|
||||
"tf_device.launch"() ( {
|
||||
// CHECK-COUNT-2: "tf_device.launch"
|
||||
// CHECK: "tf_device.cluster"
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
"tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
|
||||
"tf.E"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -100,17 +99,17 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.launch"() ( {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%3 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %3 : tensor<?xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -125,17 +124,17 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
|
||||
%1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2, %3 = "tf_device.launch"() ( {
|
||||
%2, %3 = "tf_device.cluster"() ( {
|
||||
%4 = "tf.A"() : () -> tensor<?xf32>
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%5 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
|
||||
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
|
||||
tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
|
||||
}
|
||||
|
||||
|
@ -1222,6 +1222,41 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// -----
|
||||
|
||||
// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
|
||||
// CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
|
||||
func @replicated_parallel_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
// CHECK: "tf._TPUCompileMlir"
|
||||
// CHECK: "tf.TPUCompileSucceededAssert"
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK: "tf.TPUExecute"
|
||||
%3 = "tf_device.parallel_execute"() ( {
|
||||
"tf.D"() : () -> ()
|
||||
tf_device.return
|
||||
}, {
|
||||
%4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) : () -> (tensor<?xi32>)
|
||||
tf_device.return %3 : tensor<?xi32>
|
||||
}
|
||||
%2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests devices are set properly for non replicated model parallelism.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
|
||||
|
@ -258,7 +258,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
|
||||
|
||||
// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
|
||||
// at head/tail of TPU cluster to run before/after TPU computation.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateTPUExtractHeadTailOutsideCompilationPass();
|
||||
|
||||
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
|
||||
|
@ -66,8 +66,7 @@ using tensorflow::shape_inference::ShapeHandle;
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
|
||||
FuncOp func) {
|
||||
Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
|
||||
// Find any return ops.
|
||||
SmallVector<ReturnOp, 4> return_ops;
|
||||
for (Block& block : func) {
|
||||
@ -137,9 +136,9 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
|
||||
cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result,
|
||||
/*truncate=*/b.getBoolAttr(false));
|
||||
}
|
||||
return mlir::Value(cast_op);
|
||||
return Value(cast_op);
|
||||
};
|
||||
for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
|
||||
for (OpOperand& use : make_early_inc_range(result.getUses())) {
|
||||
if (use.getOwner()->getDialect() != tf_dialect &&
|
||||
!IsSupportedNonTFOp(use.getOwner()))
|
||||
use.set(get_cast_op());
|
||||
@ -162,7 +161,7 @@ Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
|
||||
bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
|
||||
Operation* op, Dialect* tf_dialect) {
|
||||
bool changed = false;
|
||||
for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
|
||||
for (auto entry : zip(pass_through_operands, op->getResults())) {
|
||||
Type operand_type = std::get<0>(entry).getType();
|
||||
Value result = std::get<1>(entry);
|
||||
if (result.getType() == operand_type) continue;
|
||||
@ -204,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) {
|
||||
tf_dialect);
|
||||
}
|
||||
// TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
|
||||
if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
|
||||
if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) {
|
||||
return InferShapeForPassThroughOps(
|
||||
tensor_cast.getOperation()->getOperands(), op, tf_dialect);
|
||||
}
|
||||
@ -254,7 +253,7 @@ GetSubtypes(Type type) {
|
||||
// match the i-th operand type). Returns true if anything is changed.
|
||||
bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
|
||||
bool changed = false;
|
||||
for (auto entry : llvm::zip(operands, results)) {
|
||||
for (auto entry : zip(operands, results)) {
|
||||
Type operand_type = std::get<0>(entry).getType();
|
||||
Type result_type = std::get<1>(entry).getType();
|
||||
if (operand_type == result_type) continue;
|
||||
@ -291,14 +290,13 @@ bool InferShapeForCall(Operation* op) {
|
||||
CallInterfaceCallable callable = call_op.getCallableForCallee();
|
||||
SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>();
|
||||
if (!sym) return false;
|
||||
FuncOp func =
|
||||
dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
|
||||
FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
|
||||
if (!func) return false;
|
||||
|
||||
bool changed = false;
|
||||
// Map each of the results of the call to the returned type of the
|
||||
// function.
|
||||
for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) {
|
||||
for (auto result : zip(op->getResults(), func.getType().getResults())) {
|
||||
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
|
||||
// Skip already statically shaped results.
|
||||
if (!CanBeRefined(std::get<0>(result).getType())) continue;
|
||||
@ -323,8 +321,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
|
||||
Operation* op = infer_ti.getOperation();
|
||||
SmallVector<Type, 4> inferred;
|
||||
LogicalResult res = infer_ti.inferReturnTypes(
|
||||
op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(),
|
||||
op->getRegions(), inferred);
|
||||
op->getContext(), op->getLoc(), op->getOperands(),
|
||||
op->getAttrDictionary(), op->getRegions(), inferred);
|
||||
if (failed(res)) {
|
||||
op->emitOpError("failed to refine type as inference failed");
|
||||
return false;
|
||||
@ -335,7 +333,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
|
||||
// Map each of the results of the call to the returned type of the
|
||||
// function.
|
||||
bool changed = false;
|
||||
for (auto result : llvm::zip(op->getResults(), inferred)) {
|
||||
for (auto result : zip(op->getResults(), inferred)) {
|
||||
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
|
||||
|
||||
// Inserts a cast back to the original type if any user is not in the
|
||||
@ -356,7 +354,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
|
||||
// so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
|
||||
// scalar value).
|
||||
struct ValuePort {
|
||||
llvm::PointerUnion<Operation*, BlockArgument> producer;
|
||||
PointerUnion<Operation*, BlockArgument> producer;
|
||||
SmallVector<unsigned int, 2> port;
|
||||
|
||||
bool operator==(const ValuePort& other) const {
|
||||
@ -374,39 +372,38 @@ struct ValuePort {
|
||||
port = {0};
|
||||
}
|
||||
}
|
||||
ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
|
||||
ValuePort(PointerUnion<Operation*, BlockArgument> producer,
|
||||
SmallVector<unsigned int, 2> port)
|
||||
: producer(producer), port(port) {}
|
||||
|
||||
llvm::raw_ostream& print(llvm::raw_ostream& os) const {
|
||||
raw_ostream& print(raw_ostream& os) const {
|
||||
if (auto* op = producer.dyn_cast<Operation*>())
|
||||
os << "op " << op->getName();
|
||||
if (auto ba = producer.dyn_cast<BlockArgument>())
|
||||
os << "block_arg " << ba.getArgNumber();
|
||||
os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
|
||||
os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct ValuePortHasher {
|
||||
std::size_t operator()(const ValuePort& other) const {
|
||||
return llvm::hash_combine(
|
||||
llvm::hash_value(other.producer.getOpaqueValue()),
|
||||
llvm::hash_value(ArrayRef<unsigned int>(other.port)));
|
||||
return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
|
||||
hash_value(ArrayRef<unsigned int>(other.port)));
|
||||
}
|
||||
};
|
||||
|
||||
using ValuePortResultMap =
|
||||
std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
|
||||
using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
|
||||
using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
|
||||
using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
|
||||
using ComputedQueryFn = function_ref<bool(ValuePort)>;
|
||||
using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
|
||||
using ValuePortInputs = SmallVectorImpl<ValuePort>;
|
||||
|
||||
// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
|
||||
// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
|
||||
// intended to be switched to op interfaces once more refined.
|
||||
LogicalResult InputsRequiredForOutput(ValuePort value_port,
|
||||
ComputedQueryFn has_been_computed,
|
||||
ValuePortInputs* inputs) {
|
||||
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
|
||||
ComputedQueryFn has_been_computed,
|
||||
ValuePortInputs* inputs) {
|
||||
auto op = value_port.producer.dyn_cast<Operation*>();
|
||||
auto& port = value_port.port;
|
||||
if (!op) return failure();
|
||||
@ -460,26 +457,94 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
|
||||
// Context used during ShapeInference. This class contains common information
|
||||
// that is required by the individual shape inference helper functions (e.g.,
|
||||
// TF Graph version, constant values computed, etc.)
|
||||
class ShapeInference {
|
||||
public:
|
||||
ShapeInference(int64_t graph_version, MLIRContext* context);
|
||||
|
||||
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
|
||||
ValuePortInputs* inputs) {
|
||||
return ::mlir::TF::ComputeInputsRequiredForOutput(
|
||||
value_port,
|
||||
[this](const ValuePort& port) {
|
||||
return results_.find(port) != results_.end();
|
||||
},
|
||||
inputs);
|
||||
}
|
||||
|
||||
Attribute ComputeOutputComponent(const ValuePort& value_port) {
|
||||
return ::mlir::TF::ComputeOutputComponent(
|
||||
value_port, [this](const ValuePort& port) { return results_[port]; });
|
||||
}
|
||||
|
||||
// Returns ShapeHandle if the op result could be computed as shape.
|
||||
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
|
||||
|
||||
void RecordValue(const ValuePort& value_port, Attribute value) {
|
||||
results_[value_port] = value;
|
||||
}
|
||||
|
||||
// Performs shape inference on the provided op and return true if the type of
|
||||
// at least one result has been changed.
|
||||
// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
|
||||
// `graph_version` indicates the current GraphDef compatibility versions
|
||||
// (the versions field in graph.proto).
|
||||
bool InferShapeForSingleOperation(Operation* op);
|
||||
|
||||
// Infers shape on the provided region, including nested ones, iterate until
|
||||
// fix point with a limit of max_iteration. Returns success if fix point is
|
||||
// reached before max_iteration.
|
||||
LogicalResult InferShapeUntilFixPoint(Region* region,
|
||||
int64_t max_iteration = 10);
|
||||
|
||||
// Updates input types and refine shapes inside body of functions that are
|
||||
// attached to ControlFlow ops (If/While). These functions include Then/Else
|
||||
// branches of IfOp and Cond/Body functions of WhileOp. These functions share
|
||||
// following common properties:
|
||||
// 1) They are never reused, ie. having a single use in module.
|
||||
// 2) Their input types match those of their parent ops (excluding inputs
|
||||
// like predicate).
|
||||
// Returns a boolean indicating whether any change has been applied.
|
||||
LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
||||
ArrayRef<Type> input_types,
|
||||
int64_t max_iteration);
|
||||
|
||||
// Propagate the shapes to the functions named.
|
||||
LogicalResult PropagateShapeToFunctions(
|
||||
ModuleOp module, Operation::operand_type_range input_types,
|
||||
ArrayRef<StringRef> func_names, int64_t max_iteration);
|
||||
|
||||
// Shape propagation for call/control flow ops.
|
||||
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||
int64_t max_iteration);
|
||||
|
||||
private:
|
||||
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
||||
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
|
||||
// corresponds to a constant value.
|
||||
ValuePortResultMap results_;
|
||||
int64_t graph_version_;
|
||||
MLIRContext* context_;
|
||||
Dialect* tf_dialect_;
|
||||
};
|
||||
|
||||
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
|
||||
: graph_version_(graph_version) {
|
||||
context_ = context;
|
||||
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
|
||||
}
|
||||
|
||||
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||
InferenceContext* ic) {
|
||||
LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
|
||||
auto rt = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
|
||||
int dim_size = rt.getDimSize(0);
|
||||
|
||||
// Worklist to direct partial evaluation.
|
||||
llvm::SmallVector<ValuePort, 4> worklist;
|
||||
// The ValuePort evaluated results.
|
||||
// TODO(jpienaar): This could be cached across invocations (e.g., part of some
|
||||
// inference context).
|
||||
ValuePortResultMap evaluated;
|
||||
// Returns whether a ValuePort has been previously computed.
|
||||
auto has_been_computed = [&evaluated](const ValuePort& port) {
|
||||
return evaluated.find(port) != evaluated.end();
|
||||
};
|
||||
// Returns previously computed ValuePort value.
|
||||
auto values = [&evaluated](const ValuePort& port) -> Attribute {
|
||||
return evaluated[port];
|
||||
};
|
||||
SmallVector<ValuePort, 4> worklist;
|
||||
|
||||
// Simple evaluator that attempts to partially evaluate the input value even
|
||||
// if unable to evaluate the complete output. Below follows a simple stack
|
||||
@ -498,7 +563,7 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
|
||||
LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
|
||||
|
||||
SmallVector<ValuePort, 4> inputs;
|
||||
auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
|
||||
auto res = ComputeInputsRequiredForOutput(front, &inputs);
|
||||
if (failed(res)) {
|
||||
// Abort if unable to find which required inputs need to be computed.
|
||||
worklist.clear();
|
||||
@ -513,16 +578,16 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ret = ComputeOutputComponent(front, values);
|
||||
auto ret = ComputeOutputComponent(front);
|
||||
if (!ret) continue;
|
||||
|
||||
evaluated[front] = ret;
|
||||
RecordValue(front, ret);
|
||||
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
|
||||
|
||||
// If worklist is empty, then this is the root query op.
|
||||
if (worklist.empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
|
||||
if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
|
||||
if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (dea.getNumElements() != 1) {
|
||||
LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
|
||||
return {};
|
||||
@ -536,9 +601,8 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
|
||||
return ic->MakeShape(dims);
|
||||
}
|
||||
|
||||
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
int64_t graph_version) {
|
||||
assert(tf_dialect == op->getDialect());
|
||||
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
assert(tf_dialect_ == op->getDialect());
|
||||
// The shape function of these ops sometimes does not propagate subtypes
|
||||
// (handle shapes) for resource and variant types. We use a simple passthrough
|
||||
// to make sure they are preserved in the output.
|
||||
@ -550,7 +614,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
// If no result for this op needs shape inference, we have a fast-path return.
|
||||
// But if the type is a resource/variant, we do not skip it because we might
|
||||
// not have the handle shapes.
|
||||
if (llvm::none_of(op->getResultTypes(), CanBeRefined)) {
|
||||
if (none_of(op->getResultTypes(), CanBeRefined)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
|
||||
<< op->getName() << "'.\n");
|
||||
return false;
|
||||
@ -565,8 +629,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
|
||||
// the end of this function.
|
||||
if (isa<CastOp>(op) &&
|
||||
llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
|
||||
return user->getDialect() != tf_dialect;
|
||||
all_of(op->getResult(0).getUsers(), [&](Operation* user) {
|
||||
return user->getDialect() != tf_dialect_;
|
||||
})) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
|
||||
"dialect operation users '"
|
||||
@ -646,7 +710,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
// Perform the shape inference using an InferenceContext with the input
|
||||
// shapes. This object is abstracting the information that the ShapeInference
|
||||
// function operates on.
|
||||
InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
|
||||
InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
|
||||
input_shapes, input_tensors,
|
||||
/*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
|
||||
auto status = c.Run(op_reg_data->shape_inference_fn);
|
||||
@ -659,7 +723,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
// Determine if, during shape computation, the shape functions attempted to
|
||||
// query an input operand as shape where the input was not known/constant.
|
||||
bool requires_inputs =
|
||||
llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
|
||||
any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
|
||||
return c.requested_input_tensor_as_partial_shape(input) &&
|
||||
!input_tensors[input];
|
||||
});
|
||||
@ -723,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
new_element_type.isa<TF::VariantType>()) {
|
||||
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
|
||||
if (handle_shapes_types) {
|
||||
llvm::SmallVector<mlir::TensorType, 1> subtypes;
|
||||
SmallVector<TensorType, 1> subtypes;
|
||||
OpBuilder b(op);
|
||||
for (const auto& shape_n_type : *handle_shapes_types) {
|
||||
Type element_type;
|
||||
@ -743,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
if (result.getType() == new_type) continue;
|
||||
// Inserts a cast back to the original type if any user is not in the TF
|
||||
// dialect.
|
||||
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
|
||||
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
|
||||
result.getType());
|
||||
// Finally we inferred the shape and replace the type for this result.
|
||||
result.setType(new_type);
|
||||
@ -755,23 +819,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Updates input types and refine shapes inside body of functions that are
|
||||
// attached to ControlFlow ops (If/While). These functions include Then/Else
|
||||
// branches of IfOp and Cond/Body functions of WhileOp. These functions share
|
||||
// following common properties:
|
||||
// 1) They are never reused, ie. having a single use in module.
|
||||
// 2) Their input types match those of their parent ops (excluding inputs like
|
||||
// predicate).
|
||||
// Returns a boolean indicating whether any change has been applied.
|
||||
LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
||||
llvm::ArrayRef<Type> input_types,
|
||||
int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
LogicalResult ShapeInference::RefineShapeForControlFlowFunc(
|
||||
FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) {
|
||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
||||
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
int num_uses = std::distance(func_uses->begin(), func_uses->end());
|
||||
if (num_uses != 1) {
|
||||
func.emitWarning(llvm::formatv(
|
||||
func.emitWarning(formatv(
|
||||
"expected control flow function {0} to have exactly 1 use, found {1}.",
|
||||
func.getName(), num_uses));
|
||||
return failure();
|
||||
@ -785,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
||||
arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
|
||||
}
|
||||
|
||||
auto res =
|
||||
InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration);
|
||||
auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration);
|
||||
if (failed(res)) return res;
|
||||
|
||||
auto new_return_types = InferShapeForFunctionReturnType(func);
|
||||
@ -798,20 +851,18 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PropagateShapeToFunctions(
|
||||
LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||
ModuleOp module, Operation::operand_type_range input_types,
|
||||
llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
bool success = true;
|
||||
ArrayRef<StringRef> func_names, int64_t max_iteration) {
|
||||
bool all_succeeded = true;
|
||||
auto types = llvm::to_vector<4>(input_types);
|
||||
for (auto func_name : func_names) {
|
||||
FuncOp func = module.lookupSymbol<FuncOp>(func_name);
|
||||
if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
|
||||
max_iteration))) {
|
||||
success = false;
|
||||
}
|
||||
all_succeeded =
|
||||
succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) &&
|
||||
all_succeeded;
|
||||
}
|
||||
return mlir::success(success);
|
||||
return success(all_succeeded);
|
||||
}
|
||||
|
||||
// If the callee has only one use, propagates any constant operand of call_op to
|
||||
@ -831,7 +882,7 @@ void PropagateConstantToCallee(CallOpInterface call_op,
|
||||
// the constant inside the function.
|
||||
for (auto arg : func.getArguments()) {
|
||||
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
|
||||
if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) {
|
||||
if (isa_and_nonnull<TF::ConstOp>(operand)) {
|
||||
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
|
||||
}
|
||||
}
|
||||
@ -850,33 +901,31 @@ void PropagateConstantFromCallee(CallOpInterface call_op,
|
||||
for (auto retval :
|
||||
llvm::enumerate(func.front().getTerminator()->getOperands())) {
|
||||
auto retval_op = retval.value().getDefiningOp();
|
||||
if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) {
|
||||
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
|
||||
op->getResult(retval.index())
|
||||
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||
int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||
Operation* op, int64_t max_iteration) {
|
||||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
return PropagateShapeToFunctions(
|
||||
module, llvm::drop_begin(if_op.getOperandTypes(), 1),
|
||||
{if_op.then_branch(), if_op.else_branch()}, graph_version,
|
||||
max_iteration);
|
||||
module, drop_begin(if_op.getOperandTypes(), 1),
|
||||
{if_op.then_branch(), if_op.else_branch()}, max_iteration);
|
||||
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
|
||||
{while_op.cond(), while_op.body()},
|
||||
graph_version, max_iteration);
|
||||
max_iteration);
|
||||
} else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
|
||||
CallInterfaceCallable callable = call_op.getCallableForCallee();
|
||||
if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) {
|
||||
PropagateConstantToCallee(call_op, sym, module);
|
||||
if (failed(PropagateShapeToFunctions(
|
||||
module, call_op.getArgOperands().getTypes(),
|
||||
{sym.getRootReference()}, graph_version, max_iteration))) {
|
||||
{sym.getRootReference()}, max_iteration))) {
|
||||
return failure();
|
||||
}
|
||||
PropagateConstantFromCallee(call_op, sym, module);
|
||||
@ -889,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
int64_t max_iteration) {
|
||||
MLIRContext* ctx = region->getContext();
|
||||
Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
|
||||
|
||||
// An operation folder that is used to attempt folding before inference.
|
||||
OperationFolder folder(ctx);
|
||||
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
int64_t max_iteration) {
|
||||
// An operation folder that is used to attempt folding before inference._
|
||||
OperationFolder folder(context_);
|
||||
bool changed = true;
|
||||
|
||||
// TODO(aminim): we could have a more efficient traversal by guiding the
|
||||
@ -908,14 +954,14 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
<< "Shape inference, iteration " << iteration << "\n");
|
||||
region->walk([&](Operation* op) {
|
||||
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
|
||||
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect);
|
||||
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
|
||||
// TODO(jpienaar): Debug why we can't just return here. We end up with
|
||||
// additional constant due to the propagation of constant into attached
|
||||
// function if we return already.
|
||||
}
|
||||
|
||||
if (op->getDialect() != tf_dialect) {
|
||||
changed |= InferShapeForNonTFDialectOperation(op, tf_dialect);
|
||||
if (op->getDialect() != tf_dialect_) {
|
||||
changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -924,13 +970,12 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
|
||||
// Best-effort shape inference in attached functions. Do not return
|
||||
// failure even if it doesn't get to fixed point.
|
||||
if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version,
|
||||
max_iteration))) {
|
||||
if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
|
||||
op->emitWarning() << "unable to refine shape of attached function "
|
||||
"arguments and bodies";
|
||||
}
|
||||
|
||||
changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
|
||||
changed |= InferShapeForSingleOperation(op);
|
||||
});
|
||||
}
|
||||
|
||||
@ -945,31 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||
int64_t graph_version) {
|
||||
mlir::FunctionType func_type = func.getType();
|
||||
ShapeInference context(graph_version, func.getContext());
|
||||
if (arg_shapes.empty()) {
|
||||
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
|
||||
return failure();
|
||||
// TODO(b/156276510): Verify that it is always fine to refine a function's
|
||||
// return type, as long as we do not change the argument shapes.
|
||||
if (auto return_types = InferShapeForFunctionReturnType(func)) {
|
||||
func.setType(FunctionType::get(func.getType().getInputs(),
|
||||
return_types.getValue(),
|
||||
func.getContext()));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
FunctionType func_type = func.getType();
|
||||
bool needs_refinement = false;
|
||||
llvm::SmallVector<mlir::Type, 4> new_arg_types;
|
||||
SmallVector<Type, 4> new_arg_types;
|
||||
new_arg_types.reserve(func_type.getNumInputs());
|
||||
|
||||
// Update argument types in-place using the provided arg_shapes.
|
||||
for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
|
||||
ArrayRef<int64_t> shape = arg_shapes[i];
|
||||
mlir::Type element_type;
|
||||
if (auto input_ty =
|
||||
func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) {
|
||||
Type element_type;
|
||||
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
|
||||
if (!input_ty || input_ty.getShape().size() != shape.size()) {
|
||||
return failure();
|
||||
}
|
||||
element_type = input_ty.getElementType();
|
||||
} else {
|
||||
auto unranked_input_ty =
|
||||
func_type.getInput(i).dyn_cast<mlir::TensorType>();
|
||||
auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
|
||||
if (!unranked_input_ty) {
|
||||
return failure();
|
||||
}
|
||||
element_type = unranked_input_ty.getElementType();
|
||||
}
|
||||
|
||||
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
|
||||
auto new_arg_type = RankedTensorType::get(shape, element_type);
|
||||
if (new_arg_type != func_type.getInput(i)) {
|
||||
// If the new type is more detailed, trigger shape inference.
|
||||
func.getArgument(i).setType(new_arg_type);
|
||||
@ -982,28 +1039,17 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
||||
return success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult result =
|
||||
mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
|
||||
LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody());
|
||||
if (failed(result)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto return_types = InferShapeForFunctionReturnType(func);
|
||||
func.setType(mlir::FunctionType::get(new_arg_types,
|
||||
return_types.hasValue()
|
||||
? return_types.getValue()
|
||||
: func.getType().getResults(),
|
||||
func.getContext()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult InferShapeForFunctionType(FuncOp func) {
|
||||
if (auto return_types = InferShapeForFunctionReturnType(func)) {
|
||||
func.setType(mlir::FunctionType::get(func.getType().getInputs(),
|
||||
return_types.getValue(),
|
||||
func.getContext()));
|
||||
}
|
||||
func.setType(FunctionType::get(new_arg_types,
|
||||
return_types.hasValue()
|
||||
? return_types.getValue()
|
||||
: func.getType().getResults(),
|
||||
func.getContext()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -27,30 +27,13 @@ namespace mlir {
|
||||
|
||||
namespace TF {
|
||||
|
||||
// Performs shape inference on the provided op and return true if the type of
|
||||
// at least one result has been changed.
|
||||
// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
|
||||
// `graph_version` indicates the current GraphDef compatibility versions
|
||||
// (the versions field in graph.proto).
|
||||
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
int64_t graph_version);
|
||||
|
||||
// Infers shape on the provided region, including nested ones, iterate until fix
|
||||
// point with a limit of max_iteration. Returns success if fix point is reached
|
||||
// before max_iteration.
|
||||
LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
|
||||
int64_t max_iteration = 10);
|
||||
|
||||
// Given a list of refined shapes matching the function arguments of func, runs
|
||||
// shape inference over the function to propagate this updated information.
|
||||
// If arg_shapes are empty, then argument shapes will be left unchanged.
|
||||
LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||
int64_t graph_version);
|
||||
|
||||
// Refines the return type of the given function by folding tf.Cast that
|
||||
// precedes the return instruction.
|
||||
LogicalResult InferShapeForFunctionType(FuncOp func);
|
||||
|
||||
} // namespace TF
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -58,10 +58,8 @@ struct ShapeInference
|
||||
}
|
||||
int64_t producer = producer_or.ValueOrDie();
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
InferShapeUntilFixPoint(&func.getBody(), producer);
|
||||
// TODO(yuanzx): Verify that it is always fine to refine a function's
|
||||
// return type, as long as we do not change the argument shapes.
|
||||
InferShapeForFunctionType(func);
|
||||
if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -14,11 +14,23 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
@ -30,30 +42,182 @@ namespace {
|
||||
|
||||
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
|
||||
|
||||
struct TPUExtractHeadTailOutsideCompilation
|
||||
: public PassWrapper<TPUExtractHeadTailOutsideCompilation, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
bool HasOutsideCompilationAttribute(Operation* op) {
|
||||
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
|
||||
}
|
||||
|
||||
void TPUExtractHeadTailOutsideCompilation::runOnFunction() {
|
||||
getFunction().walk([&](tf_device::LaunchOp launch) {
|
||||
Block& launch_block = launch.GetBody();
|
||||
for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) {
|
||||
// TODO(b/155115766): Handle outputs that should be inputs to TPU
|
||||
// LaunchOp.
|
||||
if (auto attr =
|
||||
op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
|
||||
op.moveBefore(launch);
|
||||
} else {
|
||||
// Returns whether all operands of `op` are from values inside the
|
||||
// `input_value_set`.
|
||||
bool OpContainsOperandsFromSet(Operation* op,
|
||||
const llvm::SetVector<Value>& input_value_set) {
|
||||
for (auto operand : op->getOperands())
|
||||
if (input_value_set.count(operand) == 0) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void RecordOutsideCompiledOpsAndUsages(
|
||||
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
|
||||
llvm::SetVector<Value>* outside_compiled_op_usages) {
|
||||
if (HasOutsideCompilationAttribute(op) &&
|
||||
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
|
||||
outside_compiled_ops->insert(op);
|
||||
outside_compiled_op_usages->insert(op->getResults().begin(),
|
||||
op->getResults().end());
|
||||
}
|
||||
}
|
||||
|
||||
// Traverses the MLIR graph and returns a set of ops that
|
||||
// are connected to inputs of TPU computation and outside compiled.
|
||||
void ExtractOutsideCompiledOpsConnectedToHead(
|
||||
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
|
||||
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
|
||||
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
|
||||
for (auto& usage : input_value.getUses()) {
|
||||
auto head_operation = usage.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(head_operation,
|
||||
&parent_outside_compiled_ops_at_head,
|
||||
values_used_in_host_cluster);
|
||||
}
|
||||
|
||||
// Traverse the graph and find all outside compiled ops connected from
|
||||
// the `input_value`.
|
||||
while (!parent_outside_compiled_ops_at_head.empty()) {
|
||||
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
|
||||
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
|
||||
auto op_results = head_outside_compiled_op->getOpResults();
|
||||
for (auto op_result : op_results) {
|
||||
for (auto& use : op_result.getUses()) {
|
||||
auto connected_op = use.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(connected_op,
|
||||
&connected_outside_compiled_ops,
|
||||
values_used_in_host_cluster);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
|
||||
parent_outside_compiled_ops_at_head.end());
|
||||
std::swap(parent_outside_compiled_ops_at_head,
|
||||
connected_outside_compiled_ops);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hongjunchoi): Also handle ops without inputs that are outside
|
||||
// compiled.
|
||||
//
|
||||
// Returns set of ops that are outside compiled and are directly connected
|
||||
// to inputs to the TPU computation.
|
||||
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
|
||||
tf_device::ClusterOp tpu_cluster) {
|
||||
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
|
||||
llvm::SetVector<Value> values_used_in_cluster;
|
||||
auto& cluster_region = tpu_cluster.body();
|
||||
getUsedValuesDefinedAbove(cluster_region, cluster_region,
|
||||
values_used_in_cluster);
|
||||
|
||||
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
|
||||
for (auto input_value : input_value_list)
|
||||
ExtractOutsideCompiledOpsConnectedToHead(
|
||||
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
|
||||
return outside_compiled_at_head_ops;
|
||||
}
|
||||
|
||||
// Returns output values of extracted outside compiled cluster at head that
|
||||
// are used by the TPU computation.
|
||||
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
llvm::SmallVector<Value, 8> outputs;
|
||||
outputs.reserve(head_outside_compiled_ops.size());
|
||||
|
||||
for (auto op : head_outside_compiled_ops) {
|
||||
for (Operation* user : op->getUsers()) {
|
||||
if (!head_outside_compiled_ops.count(user)) {
|
||||
outputs.append(op->result_begin(), op->result_end());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
// Creates new tf_device.launch op with outside compiled ops extracted
|
||||
// from the head of TPU computation.
|
||||
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
|
||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
if (head_outside_compiled_ops.empty())
|
||||
return llvm::Optional<tf_device::LaunchOp>();
|
||||
|
||||
// Create tf_device.launch op to separate all extracted outside compiled ops
|
||||
// before the tf_device.cluster.
|
||||
auto output_values =
|
||||
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
|
||||
|
||||
llvm::SmallVector<Type, 8> output_return_types;
|
||||
output_return_types.reserve(output_values.size());
|
||||
for (auto output : output_values)
|
||||
output_return_types.emplace_back(output.getType());
|
||||
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto host_launch_op = builder->create<tf_device::LaunchOp>(
|
||||
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
|
||||
|
||||
// Replace all usages of outside compiled ops that are used in TPU
|
||||
// computation with the results of the above created launch op.
|
||||
for (auto output_and_index : llvm::enumerate(output_values)) {
|
||||
auto output_index = output_and_index.index();
|
||||
auto output = output_and_index.value();
|
||||
for (auto& use : output.getUses()) {
|
||||
if (!head_outside_compiled_ops.count(use.getOwner()))
|
||||
use.set(host_launch_op.getResult(output_index));
|
||||
}
|
||||
}
|
||||
|
||||
// Create terminator op for the newly created launch op.
|
||||
host_launch_op.body().push_back(new Block());
|
||||
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
|
||||
auto terminator = builder->create<tf_device::ReturnOp>(
|
||||
host_launch_op.getLoc(), output_values);
|
||||
|
||||
// Move all outside compile ops from cluster op to launch op.
|
||||
for (auto outside_compiled_op : head_outside_compiled_ops)
|
||||
outside_compiled_op->moveBefore(terminator);
|
||||
|
||||
return host_launch_op;
|
||||
}
|
||||
|
||||
struct TPUExtractHeadTailOutsideCompilation
|
||||
: public PassWrapper<TPUExtractHeadTailOutsideCompilation,
|
||||
OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
||||
// Get runtime devices information from the closest parent module.
|
||||
auto module = getOperation();
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
|
||||
return signalPassFailure();
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
module.walk([&](tf_device::ClusterOp cluster) {
|
||||
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
|
||||
IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
|
||||
head_outside_compiled_ops);
|
||||
|
||||
// TODO(b/156030523): Update device attribute of newly created host launch
|
||||
// op as well as enclosing Replicate op (if TPU computation is replicated)
|
||||
// with host device names.
|
||||
|
||||
// TODO(b/155115766): Implement tail outside compiled op extraction.
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateTPUExtractHeadTailOutsideCompilationPass() {
|
||||
return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
|
||||
// Mapping for `_xla_outside_compilation` attribute to ops of a cluster.
|
||||
using ClusterMap =
|
||||
using OutsideClusterMap =
|
||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<Operation*, 8>, 8>;
|
||||
|
||||
// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
|
||||
@ -51,7 +51,8 @@ struct TPUExtractOutsideCompilation
|
||||
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
|
||||
// attribute into `clusters` This returns an error if a
|
||||
// `_xla_outside_compilation` attribute of an op is empty.
|
||||
LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
|
||||
LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
|
||||
OutsideClusterMap* clusters) {
|
||||
for (Operation& op : *block) {
|
||||
if (auto attr = op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
|
||||
if (attr.getValue().empty())
|
||||
@ -67,7 +68,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
|
||||
}
|
||||
|
||||
// Moves `cluster_ops` to associated `launch_op` body.
|
||||
void MoveClusterOpsToLaunchOp(
|
||||
void MoveOutsideClusterOpsToLaunchOp(
|
||||
tf_device::LaunchOp launch_op,
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
||||
MLIRContext* context = launch_op.getContext();
|
||||
@ -84,8 +85,8 @@ void MoveClusterOpsToLaunchOp(
|
||||
}
|
||||
|
||||
// Creates a `tf_device::LaunchOp` to wrap cluster ops.
|
||||
tf_device::LaunchOp CreateLaunchOpForCluster(OpBuilder* builder,
|
||||
Operation* last_cluster_op) {
|
||||
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
|
||||
OpBuilder* builder, Operation* last_cluster_op) {
|
||||
// TODO(b/154363171): Set the CPU device.
|
||||
// An empty string placeholder is used for the device as that will be later
|
||||
// populated with the device of the associated TPUReplicateMetadata op.
|
||||
@ -117,14 +118,14 @@ void PropagateParallelExecuteReturnToReplicate(
|
||||
|
||||
// Creates a `parallel_execute` op in place of launch with 'clusters` and
|
||||
// 'launch` as regions.
|
||||
void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
const ClusterMap& clusters) {
|
||||
OpBuilder builder(launch);
|
||||
void CreateParallelExecuteFromOutsideClusters(
|
||||
tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
|
||||
OpBuilder builder(tpu_cluster);
|
||||
// Create parallel_execute regions. The original TPU cluster computation
|
||||
// is the extra region.
|
||||
int num_regions = 1 + clusters.size();
|
||||
auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
|
||||
launch.getLoc(), num_regions, launch.results().getTypes());
|
||||
tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
|
||||
|
||||
// Move outside compilation clusters to parallel_execute regions.
|
||||
for (const auto& cluster : llvm::enumerate(clusters)) {
|
||||
@ -134,21 +135,23 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
|
||||
builder.setInsertionPointToEnd(&outside_block);
|
||||
tf_device::LaunchOp launch_op =
|
||||
CreateLaunchOpForCluster(&builder, cluster_ops.back());
|
||||
MoveClusterOpsToLaunchOp(launch_op, cluster_ops);
|
||||
CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back());
|
||||
MoveOutsideClusterOpsToLaunchOp(launch_op, cluster_ops);
|
||||
builder.setInsertionPointToEnd(&outside_block);
|
||||
// TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
|
||||
// regions either through communication with TPU parallel_execute regions
|
||||
// or modifying parallel_execute returns.
|
||||
builder.create<tf_device::ReturnOp>(launch.getLoc(), ArrayRef<Value>{});
|
||||
builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
|
||||
ArrayRef<Value>{});
|
||||
}
|
||||
|
||||
// Move the launch body to last parallel_execute block.
|
||||
Block& inside_block =
|
||||
parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1);
|
||||
builder.setInsertionPointToEnd(&inside_block);
|
||||
builder.create<tf_device::ReturnOp>(launch.getLoc(), launch.getResults());
|
||||
launch.getOperation()->moveBefore(inside_block.getTerminator());
|
||||
builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
|
||||
tpu_cluster.getResults());
|
||||
tpu_cluster.getOperation()->moveBefore(inside_block.getTerminator());
|
||||
|
||||
PropagateParallelExecuteReturnToReplicate(parallel_execute_op);
|
||||
// TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
|
||||
@ -157,17 +160,19 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
}
|
||||
|
||||
void TPUExtractOutsideCompilation::runOnFunction() {
|
||||
auto extract_result = getFunction().walk([&](tf_device::LaunchOp launch) {
|
||||
ClusterMap clusters;
|
||||
if (failed(CollectAndGroupClusterOps(&launch.GetBody(), &clusters)))
|
||||
return WalkResult::interrupt();
|
||||
auto extract_result =
|
||||
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||
OutsideClusterMap clusters;
|
||||
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
|
||||
&clusters)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
if (clusters.empty()) return WalkResult::advance();
|
||||
if (clusters.empty()) return WalkResult::advance();
|
||||
|
||||
CreateParallelExecuteFromClusters(launch, clusters);
|
||||
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters);
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (extract_result.wasInterrupted()) return signalPassFailure();
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ constexpr char kBadArrayAttrLengthMsg[] =
|
||||
//
|
||||
// Would become following ops (unimportant attributes, types are omitted):
|
||||
// %1 = "tf.Shape"(%0)
|
||||
// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = "<Serialized @tpu_func>"}
|
||||
// %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
|
||||
// "tf.TPUCompileSucceededAssert"(%2#0)
|
||||
// %3 = "tf.TPUExecute"(%0, %2#1)
|
||||
// %4 = "tf.SomeOp"(%3)
|
||||
@ -448,19 +448,20 @@ Operation* BuildCompileOp(
|
||||
// core, and all replica devices per core are grouped together.
|
||||
void AssignDevicesToReplicate(
|
||||
tf_device::ReplicateOp replicate,
|
||||
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
|
||||
llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
|
||||
tpu_devices,
|
||||
OpBuilder* builder) {
|
||||
if (!replicate) return;
|
||||
|
||||
const int num_replicas = execution_devices.size();
|
||||
const int num_cores_per_replica = execution_devices.front().size();
|
||||
const int num_replicas = tpu_devices.size();
|
||||
const int num_cores_per_replica = tpu_devices.front().size();
|
||||
|
||||
llvm::SmallVector<NamedAttribute, 8> device_attrs;
|
||||
for (int core = 0; core < num_cores_per_replica; ++core) {
|
||||
llvm::SmallVector<StringRef, 8> devices_by_core;
|
||||
devices_by_core.reserve(num_replicas);
|
||||
for (int replica = 0; replica < num_replicas; ++replica)
|
||||
devices_by_core.push_back(execution_devices[replica][core]);
|
||||
devices_by_core.push_back(tpu_devices[replica][core].device);
|
||||
|
||||
device_attrs.push_back(
|
||||
builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
|
||||
@ -492,11 +493,12 @@ LogicalResult BuildExecuteOp(
|
||||
// Creates a tf_device.parallel_execute op that wraps TPUExecute op to
|
||||
// represent execution of TPU program in multiple logical cores.
|
||||
LogicalResult BuildParallelExecuteOp(
|
||||
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
|
||||
llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
|
||||
tpu_devices,
|
||||
llvm::ArrayRef<xla::OpSharding> output_sharding_config,
|
||||
Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
|
||||
OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
|
||||
const int num_cores_per_replica = execution_devices.front().size();
|
||||
const int num_cores_per_replica = tpu_devices.front().size();
|
||||
// parallel_execute op returns concatenated list of return values of
|
||||
// all its regions.
|
||||
//
|
||||
@ -528,7 +530,7 @@ LogicalResult BuildParallelExecuteOp(
|
||||
num_cores_per_replica, cluster_func, builder, &input_list);
|
||||
if (failed(result)) return failure();
|
||||
|
||||
const bool replicated = execution_devices.size() != 1;
|
||||
const bool replicated = tpu_devices.size() != 1;
|
||||
// For each logical core, create a region with TPUExecute op.
|
||||
assert(input_list.size() == num_cores_per_replica);
|
||||
for (int core = 0; core < num_cores_per_replica; ++core) {
|
||||
@ -553,7 +555,7 @@ LogicalResult BuildParallelExecuteOp(
|
||||
// op.
|
||||
std::string device = replicated
|
||||
? tensorflow::GetDeviceAliasForLogicalCore(core)
|
||||
: execution_devices.front()[core];
|
||||
: tpu_devices.front()[core].device;
|
||||
|
||||
auto region_launch_op =
|
||||
WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);
|
||||
@ -566,13 +568,14 @@ LogicalResult BuildParallelExecuteOp(
|
||||
}
|
||||
|
||||
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
|
||||
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
|
||||
llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
|
||||
tpu_devices,
|
||||
Operation* execute_op, OpBuilder* builder) {
|
||||
const bool replicated = execution_devices.size() != 1;
|
||||
const bool replicated = tpu_devices.size() != 1;
|
||||
// If computation is replicated, use aliased device. Otherwise there is only
|
||||
// one execution device and the device is assigned to the execute op.
|
||||
std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
|
||||
: execution_devices.front().front();
|
||||
: tpu_devices.front().front().device;
|
||||
|
||||
return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
|
||||
}
|
||||
@ -687,6 +690,16 @@ LogicalResult Rewrite(
|
||||
// Create compile op.
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
builder->setInsertionPoint(cluster_func);
|
||||
|
||||
// Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
|
||||
// parallel_execute region if it exists.
|
||||
if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
|
||||
// Currently, outside compilation and model parallelism are not supported
|
||||
// together.
|
||||
assert(num_cores_per_replica == 1);
|
||||
builder->setInsertionPoint(cluster_func.getParentOp());
|
||||
}
|
||||
|
||||
Operation* compile_op = BuildCompileOp(
|
||||
cluster_func, num_replicas, num_cores_per_replica,
|
||||
tpu_device_assignment.compilation_device,
|
||||
@ -704,7 +717,7 @@ LogicalResult Rewrite(
|
||||
BuildTPUCompileSucceededAssertOp(
|
||||
compile_op, tpu_device_assignment.compilation_device, builder);
|
||||
|
||||
AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices,
|
||||
AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
|
||||
builder);
|
||||
|
||||
llvm::SmallVector<xla::OpSharding, 4> output_shardings;
|
||||
@ -712,12 +725,13 @@ LogicalResult Rewrite(
|
||||
num_cores_per_replica, cluster_func, &output_shardings);
|
||||
if (failed(result)) return failure();
|
||||
|
||||
builder->setInsertionPoint(cluster_func);
|
||||
if (num_cores_per_replica > 1) {
|
||||
// For model parallelism, tf_device.parallel_execute is used to express
|
||||
// concurrent device execution across multiple logical devices.
|
||||
|
||||
tf_device::ParallelExecuteOp execute_op;
|
||||
result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
|
||||
result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
|
||||
output_shardings, compile_op, cluster_func,
|
||||
builder, &execute_op);
|
||||
if (failed(result)) return failure();
|
||||
@ -740,7 +754,7 @@ LogicalResult Rewrite(
|
||||
if (failed(result)) return failure();
|
||||
|
||||
tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
|
||||
tpu_device_assignment.execution_devices, execute_op, builder);
|
||||
tpu_device_assignment.tpu_devices, execute_op, builder);
|
||||
cluster_func.replaceAllUsesWith(launch_op);
|
||||
}
|
||||
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@ -57,6 +58,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
@ -65,6 +67,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
@ -109,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) {
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
using mlir::NamedAttrList;
|
||||
using mlir::TensorType;
|
||||
using mlir::TF::VarHandleOp;
|
||||
using mlir::tf_saved_model::GlobalTensorOp;
|
||||
@ -306,9 +310,9 @@ class ImporterBase {
|
||||
// AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
|
||||
// a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
|
||||
// {base_name.k2 : rfc}}.
|
||||
Status ConvertFunctionCallAttribute(
|
||||
const std::string& base_name, const AttrValue& value,
|
||||
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
|
||||
Status ConvertFunctionCallAttribute(const std::string& base_name,
|
||||
const AttrValue& value,
|
||||
NamedAttrList* attributes);
|
||||
|
||||
// Helper to create either a tf_executor operation or a TF operation wrapped
|
||||
// in an island. When convert_to_legacy_call is true, converts the operation
|
||||
@ -1089,9 +1093,9 @@ StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
|
||||
return subtypes;
|
||||
}
|
||||
|
||||
Status ImporterBase::ConvertFunctionCallAttribute(
|
||||
const std::string& base_name, const AttrValue& value,
|
||||
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) {
|
||||
Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
|
||||
const AttrValue& value,
|
||||
NamedAttrList* attributes) {
|
||||
TF_ASSIGN_OR_RETURN(auto func_attr,
|
||||
ConvertFunctionCallName(value.func().name()));
|
||||
attributes->push_back(builder_.getNamedAttr(base_name, func_attr));
|
||||
@ -2428,8 +2432,8 @@ class SavedModelObjectGraphImporter : public ImporterBase {
|
||||
// Main entry point: converts all functions in the given meta graph to an MLIR
|
||||
// Module.
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes);
|
||||
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool add_default_attributes);
|
||||
|
||||
private:
|
||||
explicit SavedModelObjectGraphImporter(
|
||||
@ -3129,8 +3133,8 @@ Status CreateSavedModelIR(
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context, bool add_default_attributes) {
|
||||
GraphDebugInfo dummy_debug_info;
|
||||
const GraphDebugInfo& debug_info =
|
||||
saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
|
||||
@ -3207,17 +3211,20 @@ class SavedModelSignatureDefImporter {
|
||||
public:
|
||||
// Main entry point: converts all functions (specified by SignatureDefs) in
|
||||
// the given meta graph to an MLIR Module.
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
|
||||
mlir::MLIRContext* context) {
|
||||
SavedModelSignatureDefImporter importer(bundle, context);
|
||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||
const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context) {
|
||||
SavedModelSignatureDefImporter importer(bundle, exported_names, context);
|
||||
|
||||
return importer.ConvertSignatures();
|
||||
}
|
||||
|
||||
private:
|
||||
SavedModelSignatureDefImporter(const SavedModelBundle& bundle,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context)
|
||||
: bundle_(bundle),
|
||||
exported_names_(exported_names),
|
||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
||||
|
||||
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
||||
@ -3250,6 +3257,7 @@ class SavedModelSignatureDefImporter {
|
||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
||||
|
||||
const SavedModelBundle& bundle_;
|
||||
absl::Span<std::string> exported_names_;
|
||||
mlir::OwningModuleRef module_;
|
||||
};
|
||||
|
||||
@ -3265,6 +3273,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
GraphDebugInfo debug_info;
|
||||
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
|
||||
|
||||
llvm::StringSet<> exported_name_set;
|
||||
exported_name_set.insert(exported_names_.begin(), exported_names_.end());
|
||||
|
||||
for (const auto& key_and_signature_def : signatures) {
|
||||
const std::string& sig_def_key = key_and_signature_def.first;
|
||||
const SignatureDef& signature_def = key_and_signature_def.second;
|
||||
@ -3274,6 +3285,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
if (sig_def_key == "__saved_model_init_op") {
|
||||
continue;
|
||||
}
|
||||
if (!exported_name_set.empty() &&
|
||||
exported_name_set.count(sig_def_key) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
||||
debug_info, flib_def));
|
||||
@ -3556,12 +3571,14 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
return SavedModelObjectGraphImporter::Convert(
|
||||
saved_model, context, exported_names, add_default_attributes);
|
||||
saved_model, exported_names, context, add_default_attributes);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
||||
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
|
||||
return SavedModelSignatureDefImporter::Convert(saved_model, context);
|
||||
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context) {
|
||||
return SavedModelSignatureDefImporter::Convert(saved_model, exported_names,
|
||||
context);
|
||||
}
|
||||
|
||||
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
||||
|
@ -55,6 +55,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
// expressed with tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
||||
absl::Span<std::string> exported_names,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Serialize a MLIR module to a string.
|
||||
|
@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context) {
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
tensorflow::SessionOptions session_options;
|
||||
// Force saved model states to be restored to CPU.
|
||||
@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context);
|
||||
if (!module_or.status().ok()) {
|
||||
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||
return nullptr;
|
||||
|
@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport(
|
||||
// given MLIR `context`.
|
||||
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport(
|
||||
absl::string_view saved_model_dir,
|
||||
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
|
||||
const std::unordered_set<std::string>& tags,
|
||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -293,6 +293,12 @@ Status ConvertMLIRToXlaComputation(
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
||||
// Run shape inference pass to propagate shapes through tensor_cast operations
|
||||
// from static to dynamic shapes. This could be generated if the shape
|
||||
// inference was originally missing in a TF op but the corresponding HLO op
|
||||
// had static shape after lowering.
|
||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
||||
// Run LegalizeTFPass again because the previous legalization passes can
|
||||
// expose more graph pruning and canonicalization opportunities that are
|
||||
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
|
||||
|
@ -31,12 +31,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
@ -131,13 +133,21 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
|
||||
case DTYPE: \
|
||||
return ConvertFlatTensor<CTYPE>(input_tensor, type);
|
||||
|
||||
// TODO(fengliuai): customize the conversions for more types.
|
||||
// TODO(fengliuai): customize the conversions for quantized and string types.
|
||||
switch (input_dtype) {
|
||||
CONVERT_FLAT(DT_BOOL, bool)
|
||||
CONVERT_FLAT(DT_FLOAT, float)
|
||||
CONVERT_FLAT(DT_DOUBLE, double)
|
||||
CONVERT_FLAT(DT_INT8, int8)
|
||||
CONVERT_FLAT(DT_INT16, int16)
|
||||
CONVERT_FLAT(DT_INT32, int32)
|
||||
CONVERT_FLAT(DT_INT64, int64)
|
||||
CONVERT_FLAT(DT_UINT8, uint8)
|
||||
CONVERT_FLAT(DT_UINT16, uint16)
|
||||
CONVERT_FLAT(DT_UINT32, uint32)
|
||||
CONVERT_FLAT(DT_UINT64, uint64)
|
||||
CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
|
||||
CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
|
||||
|
||||
// BFLOAT16 is a special case that it needs to be cast to double type to
|
||||
// match its storage type.
|
||||
@ -207,12 +217,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
|
||||
|
||||
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
||||
// proto.
|
||||
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
for (const auto& val : attr.getRawStringData()) {
|
||||
output_tensor->add_string_val(val.data(), val.size());
|
||||
void ConvertStringElementsAttr(
|
||||
const DenseStringElementsAttr attr,
|
||||
protobuf::RepeatedPtrField<std::string>* output) {
|
||||
for (const auto& val : attr.getRawStringData())
|
||||
output->Add({val.data(), val.size()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
|
||||
protobuf::RepeatedField<T>* output) {
|
||||
for (const auto& val : attr.getValues<std::complex<T>>()) {
|
||||
output->Add(val.real());
|
||||
output->Add(val.imag());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||
@ -226,139 +244,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
||||
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the double_val field updated.
|
||||
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_double_val(elts.getSplatValue<double>());
|
||||
} else {
|
||||
for (auto value : elts.getValues<double>())
|
||||
output_tensor->add_double_val(value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the float_val field updated.
|
||||
Status ConvertFloatElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_float_val(elts.getSplatValue<float>());
|
||||
} else {
|
||||
for (auto value : elts.getValues<float>())
|
||||
output_tensor->add_float_val(value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the half_val field updated.
|
||||
Status ConvertHalfElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_half_val(
|
||||
(*elts.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (const auto& value : elts.getFloatValues())
|
||||
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the int_val field updated.
|
||||
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_int_val((*elts.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const auto& val : elts)
|
||||
output_tensor->add_int_val(val.getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
|
||||
if (!elts) {
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Bfloat16 is internally represented as `double` in MLIR.
|
||||
if (elts.isSplat()) {
|
||||
double v = elts.getSplatValue<double>();
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||
// Converts an MLIR elements attribute and adds it to specified repeated field.
|
||||
template <typename T>
|
||||
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
|
||||
protobuf::RepeatedField<T>* output) {
|
||||
if (attr.isSplat()) {
|
||||
output->Add(attr.getSplatValue<T>());
|
||||
} else {
|
||||
for (auto v : elts.getValues<double>()) {
|
||||
for (auto value : attr.getValues<T>()) output->Add(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute containing half values and adds it to
|
||||
// specified repeated field.
|
||||
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output_tensor) {
|
||||
if (attr.isSplat()) {
|
||||
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (const llvm::APFloat value : attr.getFloatValues())
|
||||
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute containing int values and adds it to
|
||||
// specified repeated field.
|
||||
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output) {
|
||||
if (attr.isSplat()) {
|
||||
output->Add((*attr.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output) {
|
||||
// Bfloat16 is internally represented as `double` in MLIR.
|
||||
if (attr.isSplat()) {
|
||||
double v = attr.getSplatValue<double>();
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
} else {
|
||||
for (auto v : attr.getValues<double>()) {
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the int64_val field updated.
|
||||
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const auto& val : elts)
|
||||
output_tensor->add_int64_val(val.getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with bool_val field updated.
|
||||
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
for (const auto& val : elts) {
|
||||
output_tensor->add_bool_val(val.getBoolValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
Status ConvertToTensorProto(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||
auto type = attr.getType();
|
||||
auto shape = type.getShape();
|
||||
DataType output_dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
||||
output_tensor->set_dtype(output_dtype);
|
||||
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
|
||||
output->set_dtype(output_dtype);
|
||||
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
|
||||
|
||||
if (attr.isa<OpaqueElementsAttr>())
|
||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
|
||||
|
||||
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
|
||||
|
||||
switch (output_dtype) {
|
||||
case DT_FLOAT:
|
||||
return ConvertFloatElementsAttr(attr, output_tensor);
|
||||
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
|
||||
break;
|
||||
case DT_HALF:
|
||||
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
|
||||
return ConvertHalfElementsAttr(attr, output_tensor);
|
||||
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||
output->mutable_half_val());
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
return ConvertDoubleElementsAttr(attr, output_tensor);
|
||||
ConvertElementsAttr(dense_attr, output->mutable_double_val());
|
||||
break;
|
||||
case DT_QUINT8:
|
||||
case DT_UINT8:
|
||||
case DT_INT8:
|
||||
@ -366,20 +325,40 @@ Status ConvertToTensorProto(const ElementsAttr attr,
|
||||
case DT_UINT16:
|
||||
case DT_INT16:
|
||||
case DT_INT32:
|
||||
return ConvertIntElementsAttr(attr, output_tensor);
|
||||
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
|
||||
output->mutable_int_val());
|
||||
break;
|
||||
case DT_UINT32:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
|
||||
break;
|
||||
case DT_UINT64:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
|
||||
break;
|
||||
case DT_INT64:
|
||||
return ConvertInt64ElementsAttr(attr, output_tensor);
|
||||
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
|
||||
break;
|
||||
case DT_BOOL:
|
||||
return ConvertBoolElementsAttr(attr, output_tensor);
|
||||
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
|
||||
break;
|
||||
case DT_BFLOAT16:
|
||||
return ConvertBfloat16ElementsAttr(attr, output_tensor);
|
||||
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||
output->mutable_half_val());
|
||||
break;
|
||||
case DT_STRING:
|
||||
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
|
||||
output_tensor);
|
||||
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
|
||||
output->mutable_string_val());
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
|
||||
break;
|
||||
default:
|
||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
|
||||
output_tensor);
|
||||
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
|
||||
DataTypeString(output_dtype)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
@ -99,48 +100,74 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
|
||||
EXPECT_EQ(string_values[3], mlir::StringRef("four"));
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
|
||||
class ConvertTensorTest : public ::testing::Test {
|
||||
protected:
|
||||
template <typename T>
|
||||
void VerifyConversion(std::initializer_list<T> values, DataType dtype,
|
||||
mlir::Type expected_ty) {
|
||||
mlir::Builder b(expected_ty.getContext());
|
||||
Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
|
||||
tensor.flat<T>().setValues(values);
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_ASSERT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
|
||||
EXPECT_EQ(attr.getType().getElementType(), expected_ty);
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<T>(tensor, out);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ConvertTensorTest, Simple) {
|
||||
RegisterDialects();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
|
||||
{Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(
|
||||
VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
|
||||
mlir::FloatType::getBF16(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
|
||||
{1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
|
||||
{1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
|
||||
|
||||
{
|
||||
// Create the sample tensor to convert.
|
||||
Tensor tensor(DT_HALF, TensorShape({1}));
|
||||
auto Tt = tensor.flat<Eigen::half>();
|
||||
Tt.setValues({Eigen::half(1.0)});
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
|
||||
{1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
|
||||
{1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
|
||||
{1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
|
||||
{1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_EXPECT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
|
||||
{1, 2}, DT_UINT8,
|
||||
mlir::IntegerType::get(
|
||||
8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
|
||||
{1, 2}, DT_UINT16,
|
||||
mlir::IntegerType::get(
|
||||
16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
|
||||
{1, 2}, DT_UINT32,
|
||||
mlir::IntegerType::get(
|
||||
32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
|
||||
{1, 2}, DT_UINT64,
|
||||
mlir::IntegerType::get(
|
||||
64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
|
||||
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
|
||||
EXPECT_TRUE(attr.getType().getElementType().isF16());
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<Eigen::half>(tensor, out);
|
||||
}
|
||||
|
||||
{
|
||||
// Create the sample tensor to convert.
|
||||
Tensor tensor(DT_BFLOAT16, TensorShape({2}));
|
||||
auto Tt = tensor.flat<bfloat16>();
|
||||
Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_EXPECT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
|
||||
EXPECT_TRUE(attr.getType().getElementType().isBF16());
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<bfloat16>(tensor, out);
|
||||
}
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
|
||||
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
|
||||
mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
|
||||
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
|
||||
mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -59,6 +59,18 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// static TensorFlow op prefix set.
|
||||
std::set<std::string>* GlobalOpPrefixes() {
|
||||
static std::set<std::string>* global_op_prefixes = [] {
|
||||
std::set<std::string>* result = new std::set<std::string>;
|
||||
result->insert("tf.");
|
||||
result->insert("_tf.");
|
||||
result->insert("tf_executor.");
|
||||
return result;
|
||||
}();
|
||||
return global_op_prefixes;
|
||||
}
|
||||
|
||||
// Converts a location to the debug information for the node def.
|
||||
Status ConvertLocation(mlir::Location inst_loc,
|
||||
NodeDef::ExperimentalDebugInfo* debug_info) {
|
||||
@ -268,8 +280,10 @@ StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
|
||||
// - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
|
||||
// don't need to consider ".source"/".Source" because the nodes with this
|
||||
// suffix are skipped by the caller and will not be added to the graph.
|
||||
if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
|
||||
!op_name.consume_front("tf_executor.")) {
|
||||
auto prefixes = GlobalOpPrefixes();
|
||||
if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
|
||||
return op_name.consume_front(prefix);
|
||||
})) {
|
||||
return errors::FailedPrecondition("op node '", op_name.str(),
|
||||
"' was not a TF op!");
|
||||
}
|
||||
@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) {
|
||||
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
|
||||
}
|
||||
|
||||
Status AddTensorFlowOpPrefix(std::string prefix) {
|
||||
GlobalOpPrefixes()->insert(prefix);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -34,10 +34,17 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
class ShapedType;
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Add custom op prefix for TensorFlow dialects.
|
||||
Status AddTensorFlowOpPrefix(std::string);
|
||||
|
||||
// Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
|
||||
// dialect back into a TensorFlow valid op name.
|
||||
StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);
|
||||
|
@ -164,12 +164,19 @@ std::string GetTPUCompilationDevice(Device system_device) {
|
||||
return DeviceNameUtils::ParsedNameToString(system_device);
|
||||
}
|
||||
|
||||
// Finds the host CPU device for a given TPU device.
|
||||
std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
|
||||
tpu_device.type = DEVICE_CPU;
|
||||
tpu_device.id = 0;
|
||||
return DeviceNameUtils::ParsedNameToString(tpu_device);
|
||||
}
|
||||
|
||||
// Determines execution devices when topology and device assignment are not
|
||||
// defined. This is a special case where a single core computation is replicated
|
||||
// to every core in the mesh. TPU devices are simply added to
|
||||
// `execution_devices` of one replica. `num_replicas` must be 1 or the total
|
||||
// number of TPU devices available, and `num_cores_per_replica` must be 1.
|
||||
StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
|
||||
StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
|
||||
int num_replicas, int num_cores_per_replica,
|
||||
llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
|
||||
const int num_tasks = tpu_devices.size();
|
||||
@ -185,17 +192,18 @@ StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
|
||||
"'num_cores_per_replica' must be equal to 1, got ",
|
||||
num_cores_per_replica);
|
||||
|
||||
ExecutionDevices execution_devices;
|
||||
execution_devices.reserve(num_replicas);
|
||||
TPUDevicesAndHosts devices_and_hosts;
|
||||
devices_and_hosts.reserve(num_replicas);
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
const int task = i / num_tpus_per_task;
|
||||
const int device = i % num_tpus_per_task;
|
||||
execution_devices.push_back(
|
||||
{tensorflow::DeviceNameUtils::ParsedNameToString(
|
||||
tpu_devices[task][device])});
|
||||
const auto& tpu_device = tpu_devices[task][device];
|
||||
devices_and_hosts.push_back({TPUDeviceAndHost(
|
||||
/*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
|
||||
/*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
|
||||
}
|
||||
|
||||
return execution_devices;
|
||||
return devices_and_hosts;
|
||||
}
|
||||
|
||||
// Helper struct for keeping track of task and device for an associated TPU
|
||||
@ -326,7 +334,7 @@ StatusOr<xla::Array4D<TaskAndDevice>> ParseTopologyAttr(
|
||||
// - number of device coordinates (in tuple 3) match number 'num_replicas' *
|
||||
// 'num_cores_per_replica'
|
||||
// - a TPU device associated with each device coordinate
|
||||
StatusOr<std::pair<ExecutionDevices, xla::DeviceAssignmentProto>>
|
||||
StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
|
||||
GetGeneralTPUExecutionDeviceAssignment(
|
||||
int num_replicas, int num_cores_per_replica,
|
||||
llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
|
||||
@ -361,9 +369,9 @@ GetGeneralTPUExecutionDeviceAssignment(
|
||||
std::vector<bool> used_device_ids(
|
||||
location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1),
|
||||
false);
|
||||
ExecutionDevices execution_devices(
|
||||
num_replicas,
|
||||
llvm::SmallVector<std::string, 8>(num_cores_per_replica, ""));
|
||||
TPUDevicesAndHosts devices_and_hosts(
|
||||
num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
|
||||
num_cores_per_replica, TPUDeviceAndHost()));
|
||||
xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
|
||||
int pos = 0;
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
@ -393,16 +401,18 @@ GetGeneralTPUExecutionDeviceAssignment(
|
||||
|
||||
used_device_ids[device_id] = true;
|
||||
device_assignment(replica, logical_core) = device_id;
|
||||
execution_devices[replica][logical_core] =
|
||||
DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]);
|
||||
auto& device_and_host = devices_and_hosts[replica][logical_core];
|
||||
const auto& tpu_device = tpu_devices[task][device];
|
||||
device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
|
||||
device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
|
||||
}
|
||||
}
|
||||
|
||||
xla::DeviceAssignmentProto device_assignment_proto;
|
||||
TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));
|
||||
|
||||
return std::pair<ExecutionDevices, xla::DeviceAssignmentProto>(
|
||||
std::move(execution_devices), std::move(device_assignment_proto));
|
||||
return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
|
||||
std::move(devices_and_hosts), std::move(device_assignment_proto));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
@ -30,29 +30,40 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They
|
||||
// are ordered by `num_replicas` followed by `num_cores_per_replica`.
|
||||
using ExecutionDevices =
|
||||
llvm::SmallVector<llvm::SmallVector<std::string, 8>, 8>;
|
||||
// A TPU device for execution alongside its associated host CPU device.
|
||||
struct TPUDeviceAndHost {
|
||||
TPUDeviceAndHost() {}
|
||||
TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
|
||||
: device(device), host(host) {}
|
||||
|
||||
// TPU compilation device, execution devices, and optionally execution device
|
||||
// IDs. Execution device IDs are populated if `topology` and `device_assignment`
|
||||
// are provided.
|
||||
std::string device;
|
||||
std::string host;
|
||||
};
|
||||
|
||||
// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
|
||||
// their associated host CPU devices (for outside compilation). They are ordered
|
||||
// by `num_replicas` followed by `num_cores_per_replica`.
|
||||
using TPUDevicesAndHosts =
|
||||
llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
|
||||
|
||||
// TPU compilation device, execution and associated host devices, and optionally
|
||||
// execution device IDs. Execution device IDs are populated if `topology` and
|
||||
// `device_assignment` are provided.
|
||||
struct TPUDeviceAssignment {
|
||||
TPUDeviceAssignment(llvm::StringRef compilation_device,
|
||||
ExecutionDevices&& execution_devices)
|
||||
TPUDevicesAndHosts&& tpu_devices)
|
||||
: compilation_device(compilation_device),
|
||||
execution_devices(std::move(execution_devices)) {}
|
||||
tpu_devices(std::move(tpu_devices)) {}
|
||||
|
||||
TPUDeviceAssignment(llvm::StringRef compilation_device,
|
||||
ExecutionDevices&& execution_devices,
|
||||
TPUDevicesAndHosts&& tpu_devices,
|
||||
xla::DeviceAssignmentProto&& xla_device_assignment)
|
||||
: compilation_device(compilation_device),
|
||||
execution_devices(std::move(execution_devices)),
|
||||
tpu_devices(std::move(tpu_devices)),
|
||||
xla_device_assignment(std::move(xla_device_assignment)) {}
|
||||
|
||||
std::string compilation_device;
|
||||
ExecutionDevices execution_devices;
|
||||
TPUDevicesAndHosts tpu_devices;
|
||||
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
|
||||
};
|
||||
|
||||
|
@ -323,30 +323,46 @@ TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) {
|
||||
|
||||
TF_ASSERT_OK(status_or.status());
|
||||
|
||||
auto& tpu_device_assignment = status_or.ValueOrDie();
|
||||
const auto& tpu_device_assignment = status_or.ValueOrDie();
|
||||
EXPECT_EQ(tpu_device_assignment.compilation_device,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
auto& execution_devices = tpu_device_assignment.execution_devices;
|
||||
ASSERT_EQ(execution_devices.size(), 8);
|
||||
for (const auto& replica_execution_device : execution_devices)
|
||||
ASSERT_EQ(replica_execution_device.size(), 1);
|
||||
const auto& tpu_devices = tpu_device_assignment.tpu_devices;
|
||||
ASSERT_EQ(tpu_devices.size(), 8);
|
||||
for (const auto& replica_tpu_devices : tpu_devices)
|
||||
ASSERT_EQ(replica_tpu_devices.size(), 1);
|
||||
|
||||
EXPECT_EQ(execution_devices[0][0],
|
||||
EXPECT_EQ(tpu_devices[0][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[1][0],
|
||||
EXPECT_EQ(tpu_devices[0][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:1");
|
||||
EXPECT_EQ(execution_devices[2][0],
|
||||
EXPECT_EQ(tpu_devices[1][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[2][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:2");
|
||||
EXPECT_EQ(execution_devices[3][0],
|
||||
EXPECT_EQ(tpu_devices[2][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[3][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:3");
|
||||
EXPECT_EQ(execution_devices[4][0],
|
||||
EXPECT_EQ(tpu_devices[3][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[4][0].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[5][0],
|
||||
EXPECT_EQ(tpu_devices[4][0].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[5][0].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:1");
|
||||
EXPECT_EQ(execution_devices[6][0],
|
||||
EXPECT_EQ(tpu_devices[5][0].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[6][0].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:2");
|
||||
EXPECT_EQ(execution_devices[7][0],
|
||||
EXPECT_EQ(tpu_devices[6][0].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[7][0].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:3");
|
||||
EXPECT_EQ(tpu_devices[7][0].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
|
||||
EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
|
||||
}
|
||||
@ -410,30 +426,46 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) {
|
||||
|
||||
TF_ASSERT_OK(status_or.status());
|
||||
|
||||
auto& tpu_device_assignment = status_or.ValueOrDie();
|
||||
const auto& tpu_device_assignment = status_or.ValueOrDie();
|
||||
EXPECT_EQ(tpu_device_assignment.compilation_device,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
auto& execution_devices = tpu_device_assignment.execution_devices;
|
||||
ASSERT_EQ(execution_devices.size(), 4);
|
||||
for (const auto& replica_execution_device : execution_devices)
|
||||
ASSERT_EQ(replica_execution_device.size(), 2);
|
||||
const auto& tpu_devices = tpu_device_assignment.tpu_devices;
|
||||
ASSERT_EQ(tpu_devices.size(), 4);
|
||||
for (const auto& replica_tpu_devices : tpu_devices)
|
||||
ASSERT_EQ(replica_tpu_devices.size(), 2);
|
||||
|
||||
EXPECT_EQ(execution_devices[0][0],
|
||||
EXPECT_EQ(tpu_devices[0][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[0][1],
|
||||
EXPECT_EQ(tpu_devices[0][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[0][1].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:3");
|
||||
EXPECT_EQ(execution_devices[1][0],
|
||||
EXPECT_EQ(tpu_devices[0][1].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:1");
|
||||
EXPECT_EQ(execution_devices[1][1],
|
||||
EXPECT_EQ(tpu_devices[1][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][1].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:2");
|
||||
EXPECT_EQ(execution_devices[2][0],
|
||||
EXPECT_EQ(tpu_devices[1][1].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[2][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:3");
|
||||
EXPECT_EQ(execution_devices[2][1],
|
||||
EXPECT_EQ(tpu_devices[2][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[2][1].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[3][0],
|
||||
EXPECT_EQ(tpu_devices[2][1].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[3][0].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:2");
|
||||
EXPECT_EQ(execution_devices[3][1],
|
||||
EXPECT_EQ(tpu_devices[3][0].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[3][1].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:1");
|
||||
EXPECT_EQ(tpu_devices[3][1].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
|
||||
auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
|
||||
ASSERT_TRUE(xla_device_assignment.hasValue());
|
||||
@ -511,23 +543,35 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
|
||||
EXPECT_EQ(tpu_device_assignment.compilation_device,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
|
||||
auto& execution_devices = tpu_device_assignment.execution_devices;
|
||||
ASSERT_EQ(execution_devices.size(), 2);
|
||||
for (const auto& replica_execution_device : execution_devices)
|
||||
ASSERT_EQ(replica_execution_device.size(), 3);
|
||||
auto& tpu_devices = tpu_device_assignment.tpu_devices;
|
||||
ASSERT_EQ(tpu_devices.size(), 2);
|
||||
for (const auto& replica_tpu_devices : tpu_devices)
|
||||
ASSERT_EQ(replica_tpu_devices.size(), 3);
|
||||
|
||||
EXPECT_EQ(execution_devices[0][0],
|
||||
EXPECT_EQ(tpu_devices[0][0].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:1");
|
||||
EXPECT_EQ(execution_devices[0][1],
|
||||
EXPECT_EQ(tpu_devices[0][0].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[0][1].device,
|
||||
"/job:worker/replica:0/task:1/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[0][2],
|
||||
EXPECT_EQ(tpu_devices[0][1].host,
|
||||
"/job:worker/replica:0/task:1/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[0][2].device,
|
||||
"/job:worker/replica:0/task:2/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[1][0],
|
||||
EXPECT_EQ(tpu_devices[0][2].host,
|
||||
"/job:worker/replica:0/task:2/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][0].device,
|
||||
"/job:worker/replica:0/task:2/device:TPU:1");
|
||||
EXPECT_EQ(execution_devices[1][1],
|
||||
EXPECT_EQ(tpu_devices[1][0].host,
|
||||
"/job:worker/replica:0/task:2/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][1].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:0");
|
||||
EXPECT_EQ(execution_devices[1][2],
|
||||
EXPECT_EQ(tpu_devices[1][1].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
EXPECT_EQ(tpu_devices[1][2].device,
|
||||
"/job:worker/replica:0/task:0/device:TPU:1");
|
||||
EXPECT_EQ(tpu_devices[1][2].host,
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
|
||||
auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
|
||||
ASSERT_TRUE(xla_device_assignment.hasValue());
|
||||
|
@ -104,26 +104,24 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
|
||||
if (import_saved_model_object_graph) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelObjectGraphToMlirImport(
|
||||
input_filename, tags, absl::Span<std::string>(exported_names),
|
||||
&context);
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
} else if (import_saved_model_signature_defs) {
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
mlir::MLIRContext context;
|
||||
|
||||
auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
|
||||
input_filename, tags, &context);
|
||||
input_filename, tags, exported_names, &context);
|
||||
if (!module) return 1;
|
||||
|
||||
module->print(output->os());
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -39,7 +40,7 @@ gentbl(
|
||||
"ir/tfjs_ops.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -131,10 +132,106 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "json_translate_lib",
|
||||
srcs = [
|
||||
"translate/json_translate.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"translate/json_translate.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_js",
|
||||
":tensorflow_js_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:export_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_to_tfjs_json",
|
||||
srcs = ["translate/tf_to_tfjs_json.cc"],
|
||||
hdrs = [
|
||||
"translate/tf_to_tfjs_json.h",
|
||||
],
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
":tfjs_optimize",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "json_translate",
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "tf_tfjs_translate_main",
|
||||
srcs = [
|
||||
"translate/tf_tfjs_translate.cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf_tfjs_translate",
|
||||
srcs = [":tf_tfjs_translate_main"],
|
||||
deps = [
|
||||
":json_translate_lib",
|
||||
":tensorflow_js_passes",
|
||||
":tf_to_tfjs_json",
|
||||
":tfjs_optimize",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tfjs {
|
||||
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#define TFJS_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow.js dialect definitions
|
||||
|
23
tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
Normal file
23
tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
Normal file
@ -0,0 +1,23 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"pbtxt",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
78
tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
Normal file
78
tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
Normal file
@ -0,0 +1,78 @@
|
||||
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0"
|
||||
input: "input1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "Add"
|
||||
input: "Add"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
||||
|
||||
# CHECK: "name": "input0"
|
||||
# CHECK-NEXT: "op": "Placeholder"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "input1",
|
||||
# CHECK-NEXT: "op": "Placeholder"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Add"
|
||||
# CHECK-NEXT: "op": "AddV2"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "input0"
|
||||
# CHECK-NEXT: "input1"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Mul1"
|
||||
# CHECK-NEXT: "op": "Mul"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "Add"
|
||||
# CHECK-NEXT: "Add"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "name": "Mul"
|
||||
# CHECK-NEXT: "op": "_Retval"
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK-NEXT: "Mul1"
|
||||
# CHECK: "type": "DT_INT32"
|
||||
# CHECK: "library"
|
||||
# CHECK: "versions"
|
||||
# CHECK: "producer": 27
|
||||
|
175
tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
Normal file
175
tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
Normal file
@ -0,0 +1,175 @@
|
||||
# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
|
||||
# Add two tensor<4xi32> inputs and return the result
|
||||
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "alpha"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu"
|
||||
op: "Relu"
|
||||
input: "input0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Neg"
|
||||
op: "Neg"
|
||||
input: "input0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Relu1"
|
||||
op: "Relu"
|
||||
input: "Neg"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Mul"
|
||||
op: "Mul"
|
||||
input: "alpha"
|
||||
input: "Relu1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "Relu"
|
||||
input: "Mul"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
input: "Add"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
}
|
||||
versions {
|
||||
producer: 344
|
||||
}
|
||||
|
||||
# CHECK: "node":
|
||||
# CHECK: "name": "input0",
|
||||
# CHECK-NEXT: "op": "Placeholder",
|
||||
# CHECK-NEXT: "attr":
|
||||
# CHECK: "type": "DT_FLOAT"
|
||||
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
|
||||
# CHECK-NEXT: "op": "Const",
|
||||
# CHECK-NEXT: "attr":
|
||||
# CHECK: "value":
|
||||
# CHECK: "tensor":
|
||||
# CHECK: "dtype": "DT_FLOAT",
|
||||
# CHECK: "tensorShape": {},
|
||||
# CHECK: "floatVal":
|
||||
# CHECK: -0.5
|
||||
# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
|
||||
# CHECK-NEXT: "op": "Prelu",
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK: "input0",
|
||||
# CHECK: "Add.Relu.Neg.Relu1.Mul"
|
||||
# CHECK: "attr":
|
||||
# CHECK: "_output_shapes":
|
||||
# CHECK: "list":
|
||||
# CHECK: "shape":
|
||||
# CHECK: "dim":
|
||||
# CHECK: "size": "10"
|
||||
# CHECK: "experimentalDebugInfo": {}
|
||||
# CHECK: "name": "Add",
|
||||
# CHECK-NEXT: "op": "_Retval",
|
||||
# CHECK-NEXT: "input":
|
||||
# CHECK: "Add.Relu.Neg.Relu1.Mul1"
|
||||
# CHECK: "attr":
|
||||
# CHECK: "T":
|
||||
# CHECK: "type": "DT_FLOAT"
|
||||
# CHECK: "library": {},
|
||||
# CHECK: "versions":
|
||||
# CHECK: "producer": 344
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
|
||||
@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
|
||||
// Canonicalize, CSE etc.
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
|
||||
// raise to executor dialect in order to use GraphDef converter
|
||||
pm->addNestedPass<mlir::FuncOp>(
|
||||
mlir::CreateFunctionalToExecutorDialectConversionPass());
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
105
tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
Normal file
105
tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
using mlir::ModuleOp;
|
||||
using mlir::TranslateFromMLIRRegistration;
|
||||
using std::string;
|
||||
using tensorflow::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
// Translates the given MLIR module in the TFJS dialect to TFJS JSON
|
||||
// format. Returns false on success.
|
||||
//
|
||||
bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
|
||||
std::string* serialized_json) {
|
||||
string json_output;
|
||||
// Allow TF to treat TFJS ops as TF ops.
|
||||
if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
|
||||
LOG(ERROR) << "Failed to add tfjs op prefix.";
|
||||
return false;
|
||||
}
|
||||
tensorflow::GraphExportConfig confs;
|
||||
confs.export_shapes = true;
|
||||
confs.export_library = true;
|
||||
tensorflow::FunctionLibraryDefinition flib_def(
|
||||
tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
|
||||
absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
|
||||
auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
|
||||
auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
|
||||
&control_ret_nodes);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Graph export failed: " << status;
|
||||
return false;
|
||||
}
|
||||
auto graphdef = absl::make_unique<tensorflow::GraphDef>();
|
||||
graph->ToGraphDef(graphdef.get());
|
||||
|
||||
// Replace the _Arg nodes of the main function with Placeholder op.
|
||||
auto nodes = graphdef->mutable_node();
|
||||
for (const auto& node : llvm::enumerate(*nodes)) {
|
||||
if (node.value().op() == "_Arg") {
|
||||
nodes->Mutable(node.index())->set_op("Placeholder");
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::protobuf::util::JsonPrintOptions json_options;
|
||||
json_options.add_whitespace = true;
|
||||
auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString(
|
||||
*graphdef, &json_output, json_options);
|
||||
if (!jsonStatus.ok()) {
|
||||
LOG(ERROR) << "Proto2Json failed: " << status;
|
||||
return false;
|
||||
}
|
||||
*serialized_json = std::move(json_output);
|
||||
return true;
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirToJSONFileTranslateFunction(
|
||||
ModuleOp module, llvm::raw_ostream& output) {
|
||||
std::string serialized_json;
|
||||
if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
|
||||
return mlir::failure();
|
||||
|
||||
output << serialized_json;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
|
||||
"mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);
|
31
tensorflow/compiler/mlir/tfjs/translate/json_translate.h
Normal file
31
tensorflow/compiler/mlir/tfjs/translate/json_translate.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tfjs {
|
||||
|
||||
// Translates the given MLIR `module` into a JSON string. Returns true if
|
||||
// translation fails, otherwise returns false.
|
||||
bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
|
||||
std::string* serialized_json);
|
||||
} // namespace tfjs
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
|
173
tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
Normal file
173
tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
Normal file
@ -0,0 +1,173 @@
|
||||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
using llvm::cl::opt;
|
||||
using mlir::MLIRContext;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> input_file_name(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model_object_graph(
|
||||
"savedmodel-objectgraph-to-mlir",
|
||||
llvm::cl::desc("Import a saved model to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> import_saved_model_signature_defs(
|
||||
"savedmodel-signaturedefs-to-mlir",
|
||||
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
|
||||
llvm::cl::value_desc("dir"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_tags(
|
||||
"tf-savedmodel-tags",
|
||||
llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
|
||||
"separated by ','"),
|
||||
llvm::cl::init("serve"));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> saved_model_exported_names(
|
||||
"tf-savedmodel-exported-names",
|
||||
llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
|
||||
"(the default) means export all."),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> input_mlir(
|
||||
"input-mlir",
|
||||
llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
|
||||
"GraphDef format"),
|
||||
llvm::cl::init(false), llvm::cl::Hidden);
|
||||
// NOLINTNEXTLINE
|
||||
opt<bool> output_mlir(
|
||||
"output-mlir",
|
||||
llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// The following approach allows injecting opdefs in addition
|
||||
// to those that are already part of the global TF registry to be linked in
|
||||
// prior to importing the graph. The primary goal is for support of custom ops.
|
||||
// This is not intended to be a general solution for custom ops for the future
|
||||
// but mainly for supporting older models like mobilenet_ssd. More appropriate
|
||||
// mechanisms, such as op hints or using functions to represent composable ops
|
||||
// like https://github.com/tensorflow/community/pull/113 should be encouraged
|
||||
// going forward.
|
||||
// NOLINTNEXTLINE
|
||||
llvm::cl::list<std::string> custom_opdefs(
|
||||
"tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
|
||||
"graphdef"));
|
||||
|
||||
// Debugging flag to print function mapping in the JSON.
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool> print_function_result_mapping(
|
||||
"print-function-result-mapping",
|
||||
llvm::cl::desc(
|
||||
"Print the mapping of function result to json output buffer"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
enum TranslationStatus { kTrSuccess, kTrFailure };
|
||||
|
||||
static int PrintFunctionResultMapping(const std::string& result) {
|
||||
std::cout << result << std::endl;
|
||||
return kTrSuccess;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||
"TF GraphDef to TFJS JSON converter\n");
|
||||
|
||||
MLIRContext context;
|
||||
llvm::SourceMgr source_mgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> module;
|
||||
|
||||
if (import_saved_model_object_graph || import_saved_model_signature_defs) {
|
||||
if (input_mlir)
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
module = tensorflow::ImportSavedModel(
|
||||
import_saved_model_object_graph, import_saved_model_signature_defs,
|
||||
custom_opdefs, input_file_name, saved_model_tags,
|
||||
saved_model_exported_names, &context);
|
||||
} else {
|
||||
module = tensorflow::LoadFromGraphdefOrMlirSource(
|
||||
input_file_name, input_mlir, custom_opdefs, debug_info_file,
|
||||
input_arrays, input_dtypes, input_shapes, output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
// message. So we can just return here.
|
||||
if (!module.ok()) return kTrFailure;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
tensorflow::AddTFToTFJSConversionPasses(&pm);
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
|
||||
output_mlir, &result, &pm);
|
||||
if (!status.ok()) return kTrFailure;
|
||||
|
||||
std::string error_msg;
|
||||
auto output = mlir::openOutputFile(output_file_name, &error_msg);
|
||||
if (output == nullptr) {
|
||||
llvm::errs() << error_msg << '\n';
|
||||
return kTrFailure;
|
||||
}
|
||||
output->os() << result;
|
||||
output->keep();
|
||||
|
||||
// Print out debugging info related to function mapping.
|
||||
if (print_function_result_mapping) return PrintFunctionResultMapping(result);
|
||||
return kTrSuccess;
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user