Merge branch 'upstream/master' into interface_16x8

Change-Id: Ieb5783a0f182c92f003e7ae53da87ffbc2d62035
This commit is contained in:
Elena Zhelezina 2020-06-08 17:03:06 +01:00
commit 9573987e54
1310 changed files with 44129 additions and 12900 deletions

View File

@ -386,32 +386,32 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"

View File

@ -1,7 +1,11 @@
# TensorFlow Code of Conduct
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and our
community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of
experience, nationality, personal appearance, race, religion, or sexual identity
and orientation.
## Our Standards

View File

@ -4,26 +4,31 @@ https://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:
1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
1. It must be a bug, a feature request, or a significant problem with the
documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go
[here](https://github.com/tensorflow/tensorboard/issues).
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
------------------------
### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
- **Have I written custom code (as opposed to using a stock example script
provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue
happens on a mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
You can collect some of this information using our environment capture script:

View File

@ -90,89 +90,150 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
## Bug Fixes and Other Changes
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current
device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local
or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer
experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops.
Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient
impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body
functions are stateless. This allows multiple gradients while ops to run
in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op
inputs/outputs which are unused and hence not cached for gradient
functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent
control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow
indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a
non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm`
when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for
`MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to
`tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom
gradient aggregation and processing aggregated gradients in custom
training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`,
`Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in
`tf.function`. See the API docs for
`tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs
internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more
Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing
resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2`
in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside
`tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android
10
* TFLite Android AARs now include the C headers and APIs are required to
use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the
linter.
* Limit delegated ops to actually supported ones if a device name is
specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main
TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to
be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using
`NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum
operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a
scalar.
* Expose option to limit the number of partitions that will be delegated
to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine
operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left
rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`,
`tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int`
types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to
tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204),
tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to
`tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip
package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning
rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip
package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation()
on 32-bit ARM. This ensures a deterministic early exit instead of a hard
to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to
XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable
`TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends
deterministic `tf.nn.bias_add` back-prop functionality (and therefore
also deterministic back-prop of bias-addition in Keras layers) to
include when XLA JIT compilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment
variable `TF_DETERMINSTIC_OPS` or environment variable
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
configurations led to an exception with the message "No algorithm
worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier
debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC
[#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing
error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to
`tensorflow/core/framework:*_pyclif`.
## Thanks to our Contributors

View File

@ -589,14 +589,16 @@ void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
@ -1384,6 +1386,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
cpp_type v; \
status->status = \
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
if (!status->status.ok()) return; \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
@ -2178,6 +2181,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
}
return new_session;
} else {
LOG(ERROR) << status->status;
DCHECK_EQ(nullptr, session);
return nullptr;
}

View File

@ -369,6 +369,7 @@ tf_cuda_cc_test(
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
@ -400,7 +401,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
@ -430,7 +434,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",

View File

@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {

View File

@ -104,6 +104,14 @@ class AbstractContextInterface {
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
// Add a function (serialized FunctionDef protocol buffer) so that it can
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -12,28 +12,69 @@ package(
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
name = "headers",
name = "lib_headers",
srcs = ["parallel_device_lib.h"],
)
filegroup(
name = "lib_sources",
srcs = ["parallel_device_lib.cc"],
)
filegroup(
name = "device_headers",
srcs = ["parallel_device.h"],
)
filegroup(
name = "device_sources",
srcs = ["parallel_device.cc"],
)
filegroup(
name = "headers",
srcs = [
":device_headers",
":lib_headers",
],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
srcs = ["parallel_device.cc"],
srcs = [
":device_sources",
":lib_sources",
],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
srcs = [":sources"],
hdrs = [":headers"],
srcs = [":device_sources"],
hdrs = [":device_headers"],
visibility = ["//tensorflow:internal"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_lib",
srcs = [":lib_sources"],
hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],

View File

@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@ -49,224 +37,43 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
NamedParallelDevice(const std::string& name,
std::unique_ptr<ParallelDevice> parallel_device)
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
const std::string& name() const { return device_name_; }
const ParallelDevice& device() const { return *parallel_device_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
std::string device_name_;
std::unique_ptr<ParallelDevice> parallel_device_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
const ParallelDevice& parallel_device,
const std::string& parallel_device_name, TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
@ -289,7 +96,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
@ -300,10 +107,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
@ -319,11 +126,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> outputs;
outputs.reserve(t->num_tensors());
for (int i = 0; i < t->num_tensors(); ++i) {
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status);
if (TF_GetCode(status) != TF_OK) return result;
TensorHandlePtr this_output(
TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
outputs.emplace_back(std::move(this_output));
@ -334,15 +136,15 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
parallel_device.Execute(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
@ -356,144 +158,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
@ -501,17 +165,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
TensorHandlePtr ParallelTensorToTensorHandle(
const std::string& parallel_device_name, TFE_Context* context,
std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@ -527,12 +192,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
return ParallelTensorToTensorHandle(named_device->name(), context,
std::move(parallel_tensor), status)
.release();
}
@ -566,14 +233,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@ -585,8 +253,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
ExecuteWithSpecialOps(named_device->device(), named_device->name(),
context, std::move(typed_inputs), operation_name,
attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@ -607,8 +276,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
outputs[i] = ParallelTensorToTensorHandle(
named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@ -625,7 +294,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
@ -644,8 +313,10 @@ void AllocateParallelDevice(const char* device_name,
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
std::unique_ptr<ParallelDevice> parallel_device(
new ParallelDevice(underlying_devices_vector));
*device_info =
new NamedParallelDevice{device_name, std::move(parallel_device)};
}
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

View File

@ -0,0 +1,251 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace parallel_device {
namespace {
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
} // namespace
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor =
tensorflow::gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executor, status);
if (TF_GetCode(status) != TF_OK) return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace parallel_device {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// The number of devices operations run on.
size_t num_underlying_devices() const { return underlying_devices_.size(); }
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
TF_DataType dtype() const { return dtype_; }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

View File

@ -165,7 +165,7 @@ void RegisterParallelDevice(
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::eager::AllocateParallelDevice(
tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);

View File

@ -67,13 +67,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_target", # fixdeps: keep
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
"@llvm-project//llvm:target_base",
"@llvm-project//llvm:x86_target", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
]),
)
@ -95,7 +95,7 @@ tf_cc_test(
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:x86_target", # fixdeps: keep
],
)
@ -109,12 +109,12 @@ cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_target", # fixdeps: keep
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
"@llvm-project//llvm:target_base",
"@llvm-project//llvm:x86_target", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
]),
)
@ -286,9 +286,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:core",
"@llvm-project//llvm:ir",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target",
"@llvm-project//llvm:target_base",
],
)

View File

@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
@ -695,8 +699,6 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}

View File

@ -46,28 +46,68 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
// Returns true when the given two types have the same shape or broadcastable
// shape within the given rank. If any given shapes are non-static, this method
// returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
// Returns true when the given operand arguments have the same shape or
// broadcastable shape within the given rank. If any given shapes are
// non-static and maximum rank is within the given rank, this method returns
// true.
bool IsOperandsHaveSameShapesOrBroadcastableShape(Operation *op,
ArrayRef<unsigned> indices,
int max_bcast_rank) {
if (indices.empty()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
// First, it checks there are any inputs that has unknown rank.
bool has_unknown_shape_input = false;
bool has_same_shape = true;
bool reach_first_known_shape = false;
int64_t max_rank = -1;
ArrayRef<int64_t> pivot_shape;
SmallVector<int64_t, 4> current_shape;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
for (unsigned index : indices) {
ShapedType shaped_type =
op->getOperand(index).getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasRank()) {
// Marks that we have an unknown rank input.
has_unknown_shape_input = true;
continue;
}
max_rank = std::max(max_rank, shaped_type.getRank());
if (!shaped_type.hasStaticShape()) {
// Marks that we have an unknown shape input.
has_unknown_shape_input = true;
continue;
}
ArrayRef<int64_t> shape = shaped_type.getShape();
if (!reach_first_known_shape) {
pivot_shape = shape;
current_shape.assign(shape.begin(), shape.end());
reach_first_known_shape = true;
continue;
}
if (!pivot_shape.equals(shape)) {
has_same_shape = false;
}
// Checks if all the inputs are broadcastable since they have not all the
// same shapes.
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
result_shape)) {
return false;
}
current_shape = result_shape;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
// It will treat the unknown shape inputs as acceptable inputs for model
// compatibility unless there is an known rank that is bigger than the allowed
// broadcast maximum rank.
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
// If all the shape is known and same, CPU kernels are able to handle inputs
// regardless of dimension size.
return has_same_shape || max_rank <= max_bcast_rank;
}
//===----------------------------------------------------------------------===//
@ -1882,7 +1922,7 @@ static LogicalResult Verify(TransposeConvOp op) {
auto expected_output_type =
RankedTensorType::get(output_shape, output_type.getElementType());
if (output_type != expected_output_type) {
if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}
@ -2004,7 +2044,8 @@ static LogicalResult Verify(TransposeOp op) {
}
auto expected_output_type =
RankedTensorType::get(transposed_shape, input_type.getElementType());
if (output_type != expected_output_type) {
if (failed(
mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}

View File

@ -123,14 +123,13 @@ class TFL_RuntimePredOpTrait<string desc, Pred pred> :
string tflRuntimeDescription = desc;
}
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
int i, int j, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
" have the same shape or broadcastable shapes within the rank " #
max_bcast_rank,
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
").getType(), " # max_bcast_rank # ")">>;
class TFL_OperandsHaveSameShapesOrBroadcastableShape<
list<int> indices, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operands do not have the same shape or "
"broadcastable shapes within the rank " # max_bcast_rank,
CPred<"TFL::IsOperandsHaveSameShapesOrBroadcastableShape("
"$_op, llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result #
"}), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining
@ -213,11 +212,20 @@ class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
# dim>]>;
// Returns true if the n-th operand is ranked and has a dimension length = size
// at the rank dim.
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] == " # size>]>;
// Returns true if the n-th operand is ranked and has a dimension length <=
// size at the rank dim.
class TFL_OperandDimIsAtMost<int n, int dim, int size> : And<[
TFL_OperandIsRankedAndHasDimPred<n, dim>,
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
".getShape()[" # dim # " ] <= " # size>]>;
// Returns true if the n-th operand has unknown rank or at least rank m.
class TFL_OperandHasAtleastRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
@ -463,6 +471,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
//===----------------------------------------------------------------------===//
def TFL_AbsOp : TFL_Op<"abs", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -482,7 +491,7 @@ an output element, this operation computes \\(y = |x|\\).
}
def TFL_AddOp : TFL_Op<"add", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
ResultsBroadcastableShape,
NoSideEffect,
Commutative,
@ -669,7 +678,10 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
}]>;
}
def TFL_CeilOp: TFL_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_CeilOp: TFL_Op<"ceil", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType]> {
let summary = "Ceil operator";
let description = [{
@ -818,6 +830,7 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -1021,7 +1034,7 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
def TFL_LessEqualOp : TFL_Op<"less_equal", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
let summary = "Less_equal operator";
@ -1082,7 +1095,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
}
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult]> {
@ -1150,12 +1163,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`.
}];
let arguments = (ins
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result
);
let hasOptions = 0;
@ -1273,7 +1286,7 @@ larger than 0.
}
def TFL_NotEqualOp : TFL_Op<"not_equal", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
BinaryOpSameElementTypeConstraint,
ResultsBroadcastableShape,
Commutative,
@ -1309,7 +1322,7 @@ def TFL_DivOp : TFL_Op<"div", [
// TODO(fengliuai): NoQuantizableResult is only correct for int8
// quantization. update to handle Uint8 quantization.
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
@ -1338,7 +1351,10 @@ def TFL_DivOp : TFL_Op<"div", [
let hasFolder = 1;
}
def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_EluOp: TFL_Op<"elu", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType]> {
let summary = "Exponential Linear Unit operator";
let description = [{
Computes the exponential linear
@ -1374,10 +1390,11 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output);
}
def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
def TFL_EqualOp: TFL_Op<"equal", [
Commutative,
NoQuantizableResult,
ResultsBroadcastableShape,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
let summary = "Equal operator";
@ -1516,7 +1533,10 @@ def TFL_FillOp: TFL_Op<"fill", [
let hasOptions = 0;
}
def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_FloorOp: TFL_Op<"floor", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType]> {
let summary = "Floor operator";
let description = [{
@ -1534,7 +1554,7 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [
BinaryOpSameElementTypeConstraint,
PredOpTrait<"lhs and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
let summary = "Floor div operator";
let description = [{
@ -1559,7 +1579,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
BinaryOpSameElementTypeConstraint,
PredOpTrait<"lhs and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> {
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> {
let summary = "Division reminder";
let description = [{
@ -1578,7 +1598,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
def TFL_GreaterOp : TFL_Op<"greater", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
let summary = "Greater operator";
@ -1670,7 +1690,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
def TFL_LessOp : TFL_Op<"less", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
let summary = "Less operator";
@ -1710,7 +1730,10 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> {
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_LogicalNotOp : TFL_Op<"logical_not", [NoSideEffect, NoQuantizableResult]> {
def TFL_LogicalNotOp : TFL_Op<"logical_not", [
NoSideEffect,
SameOperandsAndResultShape,
NoQuantizableResult]> {
let summary = "Logical NOT operator";
let description = [{
@ -1794,6 +1817,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
def TFL_LogOp: TFL_Op<"log", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -1884,6 +1908,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
def TFL_MaximumOp : TFL_Op<"maximum", [
ResultsBroadcastableShape,
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
@ -2118,6 +2143,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
def TFL_MinimumOp : TFL_Op<"minimum", [
ResultsBroadcastableShape,
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
@ -2145,7 +2171,7 @@ def TFL_MulOp : TFL_Op<"mul", [
NoSideEffect,
Commutative,
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
TFL_GpuTargetOp]> {
let summary = "Multiplication operator";
@ -2171,7 +2197,10 @@ def TFL_MulOp : TFL_Op<"mul", [
let hasOptions = 1;
}
def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_NegOp: TFL_Op<"neg", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType]> {
let summary = "Negation operator";
let description = [{
@ -2247,6 +2276,9 @@ def TFL_PadOp : TFL_Op<"pad", [
TFL_OperandHasRankAtMost<0, 4>,
TFL_OperandHasRank<1, 2>,
TFL_OperandRankEquals1DimOfOperand<0, 1>,
PredOpTrait<"the first dim size of the padding argument must be at most 4",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 4>]>>,
TFL_GpuTargetOp]> {
let summary = "Padding operator";
@ -2292,6 +2324,9 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
TFL_OperandHasRank<1, 2>,
TFL_OperandHasRank<2, 0>,
TFL_OperandRankEquals1DimOfOperand<0, 1>,
PredOpTrait<"the first dim size of the padding argument must be at most 4",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 4>]>>,
PredOpTrait<"input and constant value operands must have same element type",
TFL_TCopVTEtAreSameAt<0, 2>>]> {
let summary = "Padding operator v2";
@ -2333,10 +2368,12 @@ def TFL_PadV2Op : TFL_Op<"padv2", [
let hasOptions = 1;
}
def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
def TFL_PowOp : TFL_Op<"pow", [
ResultsBroadcastableShape,
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Power operator";
let description = [{
@ -2360,7 +2397,7 @@ def TFL_PReluOp : TFL_Op<"prelu", [
NoSideEffect,
ResultsBroadcastableShape,
TFL_GpuTargetOp,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
BinaryOpSameElementTypeConstraint,
PredOpTrait<"input and output must have the same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
@ -2671,8 +2708,9 @@ def TFL_SelectOp : TFL_Op<"select", [
}
def TFL_SelectV2Op : TFL_Op<"select_v2", [
ResultsBroadcastableShape,
NoSideEffect,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1, 2], 4>,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
@ -2705,6 +2743,7 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [
def TFL_SinOp: TFL_Op<"sin", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -2752,6 +2791,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
def TFL_SqrtOp: TFL_Op<"sqrt", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -2770,6 +2810,7 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
def TFL_SquareOp: TFL_Op<"square", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
@ -2791,7 +2832,7 @@ def TFL_SquareOp: TFL_Op<"square", [
def TFL_SubOp : TFL_Op<"sub", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
NoSideEffect]> {
let summary = "Subtraction operator";
@ -2820,7 +2861,7 @@ def TFL_SubOp : TFL_Op<"sub", [
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
SameOperandsAndResultElementType,
ResultsBroadcastableShape,
NoSideEffect,
@ -3007,6 +3048,8 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultType,
SameOperandsAndResultShape,
NoSideEffect]> {
let summary = "ZerosLike operator";
@ -3319,7 +3362,9 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
}
def TFL_CastOp : TFL_Op<"cast", [
NoSideEffect, SameOperandsAndResultShape, NoQuantizableResult]> {
NoSideEffect,
SameOperandsAndResultShape,
NoQuantizableResult]> {
let summary = "Cast operator";
let description = [{

View File

@ -0,0 +1,16 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -output-arrays=mul,div,exp --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// CHECK: (%[[ARG:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> attributes {tf.entry_function = {inputs = "mul"}} {
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
// CHECK: %[[DIV:.*]] = tfl.div
%2 = "tfl.div"(%1, %arg0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
// CHECK: %[[EXP:.*]] = "tfl.exp"
%3 = "tfl.exp"(%2) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
// tfl.neg should be pruned
// CHECK-NOT: "tfl.neg"
%4 = "tfl.neg"(%3) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
// CHECK: return %[[ARG]], %[[DIV]], %[[EXP]]
return %4 : tensor<4xf32>
}

View File

@ -1529,3 +1529,22 @@ func @matmul_batchv2_unknown_dim(%arg0: tensor<?x10x15xf32>, %arg1: tensor<15x17
// CHECK-LABEL: matmul_batchv2_unknown_dim
// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<?x10x15xf32>, tensor<15x17xf32>) -> tensor<?x10x17xf32>
}
// -----
func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2 : tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
return %0 : tensor<1x1x1x2x3x4xf32>
// CHECK-LABEL: select_v2_with_6d_broadcasting
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
}
// -----
func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32> {
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x1x1x1x8x16xf32>, tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32>
return %0 : tensor<1x1x1x1x8x16xf32>
// CHECK-LABEL: maximum_with_6d_broadcasting
// CHECK: "tf.Maximum"(%arg0, %arg1)
}

View File

@ -794,6 +794,41 @@ func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>,
// -----
// CHECK-LABEL: testSelectV2
func @testSelectV2(%cond : tensor<*xi1>, %arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
%0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// CHECK-LABEL: testSelectV2WithHighDimInputs
func @testSelectV2WithHighDimInputs(%cond : tensor<1x2x3x4x5x6xi1>, %arg0 : tensor<1x2x3x4x5x6xf32>, %arg1 : tensor<1x2x3x4x5x6xf32>) -> tensor<1x2x3x4x5x6xf32> {
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
%0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<1x2x3x4x5x6xi1>, tensor<1x2x3x4x5x6xf32>, tensor<1x2x3x4x5x6xf32>) -> tensor<1x2x3x4x5x6xf32>
return %0 : tensor<1x2x3x4x5x6xf32>
}
// -----
// CHECK-LABEL: testSelectV2With4DBroadcasting
func @testSelectV2With4DBroadcasting(%cond : tensor<1x1x3x1xi1>, %arg0 : tensor<1x1x1x4xf32>, %arg1 : tensor<1x2x1x1xf32>) -> tensor<1x2x3x4xf32> {
// CHECK: "tfl.select_v2"(%arg0, %arg1, %arg2)
%0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<1x1x3x1xi1>, tensor<1x1x1x4xf32>, tensor<1x2x1x1xf32>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
}
// -----
func @testSelectV2WithWrongBroadcastableArguments(%cond : tensor<3x4xi1>, %arg0 : tensor<2x3x4xf32>, %arg1 : tensor<4x3xf32>) -> tensor<2x3x4xf32> {
// expected-error @+1 {{'tfl.select_v2' op operands don't have broadcast-compatible shapes}}
%0 = "tfl.select_v2"(%cond, %arg0, %arg1): (tensor<3x4xi1>, tensor<2x3x4xf32>, tensor<4x3xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
// -----
// CHECK-LABEL: topk
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
@ -888,6 +923,27 @@ func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> te
// -----
func @testPadUnknownPaddings(tensor<2x1x3xf32>, tensor<*xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<*xi32>):
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<*xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: testPadUnknownPaddings
// CHECK: "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<*xi32>) -> tensor<?xf32>
// CHECK: return
}
// -----
func @testPadUnsupportedPaddings(tensor<*xf32>, tensor<5x3xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<5x3xi32>):
// expected-error @+1 {{'tfl.pad' op failed to verify that the first dim size of the padding argument must be at most 4}}
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<*xf32>, tensor<5x3xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// CHECK-LABEL: testPadQuantizedU8
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
@ -958,6 +1014,29 @@ func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) ->
// -----
func @testPadV2UnknownPaddings(tensor<2x1x3xf32>, tensor<*xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<*xi32>):
%cst = constant dense<2.0> : tensor<f32>
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<*xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
// CHECK-LABEL: testPadV2UnknownPaddings
// CHECK: "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<*xi32>, tensor<f32>) -> tensor<?xf32>
// CHECK: return
}
// -----
func @testPadV2UnsupportedPaddings(tensor<*xf32>, tensor<5x3xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<5x3xi32>):
%cst = constant dense<2.0> : tensor<f32>
// expected-error @+1 {{'tfl.padv2' op failed to verify that the first dim size of the padding argument must be at most 4}}
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<*xf32>, tensor<5x3xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
@ -1322,6 +1401,14 @@ func @transpose(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi3
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: transpose_with_output_that_has_dynamic_sizes
func @transpose_with_output_that_has_dynamic_sizes(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<?x?xi32> {
// CHECK: "tfl.transpose"(%arg0, %arg1)
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// -----
@ -2070,6 +2157,16 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar
// -----
// CHECK-LABEL: testTransposeConvWithOutputThatHasDynamicSizes
func @testTransposeConvWithOutputThatHasDynamicSizes(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst)
%cst = constant unit
%0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// -----
func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
// custom op for "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>, custom_code = "Convolution2DTransposeBias"} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>

View File

@ -269,6 +269,7 @@ cc_library(
"ir/tf_executor.h",
"ir/tf_ops.h",
"ir/tf_saved_model.h",
"ir/tf_side_effects.h",
"ir/tf_structs.h",
"ir/tf_traits.h",
"ir/tf_verifiers.h",
@ -425,6 +426,7 @@ cc_library(
"transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/op_fusion.cc",
"transforms/optimize.cc",
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
@ -435,6 +437,7 @@ cc_library(
"transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc",
"transforms/resource_op_lifting.cc",
"transforms/rewrite_tpu_embedding_ops.cc",
"transforms/shape_inference.cc",
"transforms/shape_inference_pass.cc",
"transforms/sink_constant.cc",
@ -450,6 +453,7 @@ cc_library(
"transforms/tpu_extract_head_tail_outside_compilation.cc",
"transforms/tpu_extract_outside_compilation.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_outside_compilation_cluster.cc",
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_space_to_depth_pass.cc",

View File

@ -917,6 +917,30 @@ the feature dimension is the third-to-last.
}];
}
def TF_BiasAddV1Op : TF_Op<"BiasAddV1", [NoSideEffect]> {
let summary = "Adds `bias` to `value`.";
let description = [{
This is a deprecated version of BiasAdd and will be soon removed.
This is a special case of `tf.add` where `bias` is restricted to be 1-D.
Broadcasting is supported, so `value` may have any number of dimensions.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$value,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$bias
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
}
def TF_BitcastOp : TF_Op<"Bitcast", [NoSideEffect]> {
let summary = [{
Bitcasts a tensor from one type to another without copying data.
@ -2946,19 +2970,23 @@ Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
}];
let description = [{
Attributes `[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Attributes
* `[min; max]` define the clamping range for the `inputs` data.
* `inputs` values are quantized into the quantization range (
`[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]`
when it is true) and then de-quantized and output as floats in `[min; max]`
interval.
* `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Before quantization, `min` and `max` values are adjusted with the following
logic.
It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
the behavior can be unexpected:
If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
* If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
* If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
* If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
Quantization is called fake since the output is still in floating point.
@ -2984,25 +3012,30 @@ Quantization is called fake since the output is still in floating point.
def TF_FakeQuantWithMinMaxVarsOp : TF_Op<"FakeQuantWithMinMaxVars", [NoSideEffect]> {
let summary = [{
Fake-quantize the 'inputs' tensor of type float via global float scalars `min`
Fake-quantize the 'inputs' tensor of type float via global float scalars
}];
let description = [{
and `max` to 'outputs' tensor of same shape as `inputs`.
Fake-quantize the `inputs` tensor of type float via global float scalars
`min` and `max` to `outputs` tensor of same shape as `inputs`.
`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Attributes
* `[min; max]` define the clamping range for the `inputs` data.
* `inputs` values are quantized into the quantization range (
`[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]`
when it is true) and then de-quantized and output as floats in `[min; max]`
interval.
* `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Before quantization, `min` and `max` values are adjusted with the following
logic.
It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
the behavior can be unexpected:
If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
* If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
* If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
* If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
This operation has a gradient and thus allows for training `min` and `max`
@ -3029,26 +3062,31 @@ values.
def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> {
let summary = [{
Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
Fake-quantize the 'inputs' tensor of type float via per-channel floats
}];
let description = [{
`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.
Fake-quantize the `inputs` tensor of type float per-channel and one of the
shapes: `[d]`, `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max`
of shape `[d]` to `outputs` tensor of same shape as `inputs`.
`[min; max]` define the clamping range for the `inputs` data.
`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]`
when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and
then de-quantized and output as floats in `[min; max]` interval.
`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Attributes
* `[min; max]` define the clamping range for the `inputs` data.
* `inputs` values are quantized into the quantization range (
`[0; 2^num_bits - 1]` when `narrow_range` is false and `[1; 2^num_bits - 1]`
when it is true) and then de-quantized and output as floats in `[min; max]`
interval.
* `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive.
Before quantization, `min` and `max` values are adjusted with the following
logic.
It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values,
the behavior can be unexpected:
If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
* If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`.
* If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`.
* If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `,
`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`.
This operation has a gradient and thus allows for training `min` and `max`
@ -6474,7 +6512,7 @@ operation.
}];
let arguments = (ins
TF_ResourceTensor:$resource
Arg<TF_ResourceTensor, "", [TF_VariableRead]>:$resource
);
let results = (outs
@ -10801,6 +10839,38 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF__FusedConv2DOp : TF_Op<"_FusedConv2D", [NoSideEffect]> {
let summary = [{
*NOTE*: Do not invoke this operator directly in Python. Grappler is
}];
let description = [{
expected to create these operators.
}];
let arguments = (ins
TF_F32OrF64Tensor:$input,
TF_F32OrF64Tensor:$filter,
Variadic<TF_F32OrF64Tensor>:$args,
I64ArrayAttr:$strides,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations,
DefaultValuedAttr<BoolAttr, "true">:$use_cudnn_on_gpu,
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";

View File

@ -80,6 +80,26 @@ class TF_AllTypesMatch<list<string> names> :
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//
class TF_ResourceBase<string resourceKind> :
Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

View File

@ -59,6 +59,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/platform/logging.h"
@ -754,6 +755,15 @@ static LogicalResult Verify(BiasAddGradOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// BiasAddV1Op
//===----------------------------------------------------------------------===//
void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<BiasAddV1ToBiasAdd>(context);
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

View File

@ -358,6 +358,41 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall",
}];
}
def TF_ParseExampleOp : TF_Op<"ParseExample",
[NoSideEffect,
AttrSizedResultSegments,
AttrSizedOperandSegments]> {
let summary =
"Transforms a vector of tf.Example protos (as strings) into typed tensors.";
let arguments = (ins
TF_StrTensor:$serialized,
TF_StrTensor:$names,
Variadic<TF_StrTensor>:$sparse_keys,
Variadic<TF_StrTensor>:$dense_keys,
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
TF_ShapeAttrArray:$dense_shapes,
I32ElementsAttr:$result_segment_sizes,
I32ElementsAttr:$operand_segment_sizes
);
let results = (outs
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values // len(Tdense)
);
TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
let verifier = ?;
}
def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
[NoSideEffect,
AttrSizedResultSegments]> {

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the side effect definition file for TensorFlow.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace TF {
namespace ResourceEffects {
struct Variable : ::mlir::SideEffects::Resource::Base<Variable> {
StringRef getName() final { return "Variable"; }
};
struct Stack : ::mlir::SideEffects::Resource::Base<Stack> {
StringRef getName() final { return "Stack"; }
};
struct TensorArray : ::mlir::SideEffects::Resource::Base<TensorArray> {
StringRef getName() final { return "TensorArray"; }
};
} // namespace ResourceEffects
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_

View File

@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {tf_device.is_same_data_across_replicas = true}
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>)
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}
// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {tf_device.is_same_data_across_replicas = true},
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {tf_device.is_same_data_across_replicas = true}
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {xla_hlo.is_same_data_across_replicas}
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}
// CHECK-LABEL: func @_func
// CHECK-NOT: tf_device.is_same_data_across_replicas
// CHECK-NOT: xla_hlo.is_same_data_across_replicas
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>

View File

@ -344,3 +344,31 @@ func @switchn_control_input(%arg1: tensor<i32>) {
}
return
}
// CHECK-LABEL: func @single_op_island_forward_block_arg
// CHECK: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
// CHECK: tf_executor.fetch %[[CONST]], %arg0
func @single_op_island_forward_block_arg(%arg0: tensor<?x?x?x?xbf16>) -> (tensor<2048xf32>, tensor<?x?x?x?xbf16>) {
%0:2 = tf_executor.graph {
%outputs:2, %control = tf_executor.island {
%1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2048xf32>} : () -> tensor<2048xf32>
tf_executor.yield %1, %arg0 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
tf_executor.fetch %outputs#0, %outputs#1 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
return %0#0, %0#1 : tensor<2048xf32>, tensor<?x?x?x?xbf16>
}
// CHECK-LABEL: func @single_op_island_duplicate_result
// CHECK: %[[CONST:.*]], %{{.*}} = tf_executor.island wraps "tf.Const"
// CHECK: tf_executor.fetch %[[CONST]], %[[CONST]]
func @single_op_island_duplicate_result() -> (tensor<2048xf32>, tensor<2048xf32>) {
%0:2 = tf_executor.graph {
%outputs:2, %control = tf_executor.island {
%1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2048xf32>} : () -> tensor<2048xf32>
tf_executor.yield %1, %1 : tensor<2048xf32>, tensor<2048xf32>
}
tf_executor.fetch %outputs#0, %outputs#1 : tensor<2048xf32>, tensor<2048xf32>
}
return %0#0, %0#1 : tensor<2048xf32>, tensor<2048xf32>
}

View File

@ -34,6 +34,13 @@ func @testBatchMatMulV2ToMatMul(%arg0: tensor<4x3xf32>, %arg1: tensor<4x5xf32>)
// CHECK: return %0
}
// CHECK-LABEL: testBiasAddV1ToBiasAdd
func @testBiasAddV1ToBiasAdd(%arg0: tensor<*xf32>, %arg1: tensor<128xf32>) -> tensor<*xf32> {
// CHECK: "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%0 = "tf.BiasAddV1"(%arg0, %arg1) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
// CHECK-LABEL: func @testLeakyRelu
func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>) {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 1.0 : f32} : (tensor<16xf32>) -> tensor<16xf32>
@ -505,15 +512,15 @@ func @testReadVariableOpOfCastMultiUse(%arg0: tensor<!tf.resource<tensor<f32>>>)
}
// CHECK-LABEL: testMultiReadVariableOpsOfCast
func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> {
func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
%1 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<f32>
%2 = "tf.ReadVariableOp"(%0) : (tensor<*x!tf.resource>) -> tensor<f32>
return %2: tensor<f32>
return %1, %2: tensor<f32>, tensor<f32>
// CHECK: %0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: return %1
// CHECK: return %0, %1
}
// CHECK-LABEL: testRankOfRankedTensor

View File

@ -1,73 +1,12 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -tf-output-arrays=result %s | FileCheck %s
# CHECK: %[[parse_example:.*]]:8, %[[parse_example_control:.*]] = tf_executor.island wraps "tf.ParseExampleV2"(%arg0,
# CHECK: result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>
# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor<*xi64>, tensor<*xf32>
# CHECK: %[[output:.*]], %[[control:.*]]tf_executor.island wraps "tf.ParseExample"
# CHECK: operand_segment_sizes = dense<[1, 1, 0, 1, 1]> : vector<5xi32>
# CHECK: result_segment_sizes = dense<[0, 0, 0, 1]> : vector<4xi32>
# CHECK: tf_executor.fetch %[[output]] : tensor<*xi64>
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "ParseExample/Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/names"
name: "serilaized"
op: "Const"
attr {
key: "dtype"
@ -82,14 +21,16 @@ node {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: ""
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/sparse_keys"
name: "Const"
op: "Const"
attr {
key: "dtype"
@ -104,17 +45,16 @@ node {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
size: 1
}
}
string_val: "feature_key3"
string_val: "feature_key4"
string_val: "value"
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/dense_keys"
name: "Const_1"
op: "Const"
attr {
key: "dtype"
@ -128,54 +68,57 @@ node {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
}
}
string_val: "feature_key1"
string_val: "feature_key2"
string_val: "value"
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/ragged_keys"
name: "Const_2"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
type: DT_INT64
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
dtype: DT_INT64
tensor_shape {
dim {
}
}
int64_val: -1
}
}
}
}
node {
name: "ParseExample/ParseExampleV2"
op: "ParseExampleV2"
input: "input0"
input: "ParseExample/ParseExampleV2/names"
input: "ParseExample/ParseExampleV2/sparse_keys"
input: "ParseExample/ParseExampleV2/dense_keys"
input: "ParseExample/ParseExampleV2/ragged_keys"
input: "ParseExample/Const"
input: "ParseExample/Const_1"
name: "result"
op: "ParseExample"
input: "serilaized"
input: "Const"
input: "Const_1"
input: "Const_2"
attr {
key: "Ndense"
value {
i: 1
}
}
attr {
key: "Nsparse"
value {
i: 0
}
}
attr {
key: "Tdense"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_INT64
}
}
}
@ -184,29 +127,10 @@ node {
value {
list {
shape {
dim {
size: 1
}
}
shape {
}
}
}
}
attr {
key: "num_sparse"
value {
i: 2
}
}
attr {
key: "ragged_split_types"
value {
list {
}
}
}
attr {
key: "ragged_value_types"
value {
list {
}
}
}
@ -214,12 +138,12 @@ node {
key: "sparse_types"
value {
list {
type: DT_STRING
type: DT_INT64
}
}
}
}
versions {
producer: 175
library {
}
versions {
producer: 413
}

View File

@ -0,0 +1,225 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s
# CHECK: %[[parse_example:.*]]:8, %[[parse_example_control:.*]] = tf_executor.island wraps "tf.ParseExampleV2"(%arg0,
# CHECK: result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>
# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor<*xi64>, tensor<*xf32>
node {
name: "input0"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
}
node {
name: "ParseExample/Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/Const_1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/names"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/sparse_keys"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
}
}
string_val: "feature_key3"
string_val: "feature_key4"
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/dense_keys"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
}
}
string_val: "feature_key1"
string_val: "feature_key2"
}
}
}
}
node {
name: "ParseExample/ParseExampleV2/ragged_keys"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
}
}
}
}
}
}
node {
name: "ParseExample/ParseExampleV2"
op: "ParseExampleV2"
input: "input0"
input: "ParseExample/ParseExampleV2/names"
input: "ParseExample/ParseExampleV2/sparse_keys"
input: "ParseExample/ParseExampleV2/dense_keys"
input: "ParseExample/ParseExampleV2/ragged_keys"
input: "ParseExample/Const"
input: "ParseExample/Const_1"
attr {
key: "Tdense"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "dense_shapes"
value {
list {
shape {
}
shape {
}
}
}
}
attr {
key: "num_sparse"
value {
i: 2
}
}
attr {
key: "ragged_split_types"
value {
list {
}
}
}
attr {
key: "ragged_value_types"
value {
list {
}
}
}
attr {
key: "sparse_types"
value {
list {
type: DT_STRING
type: DT_INT64
}
}
}
}
versions {
producer: 175
}

View File

@ -713,6 +713,16 @@ func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> t
return %0 : tensor<1x1xf32>
}
func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}
func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@ -1570,3 +1580,19 @@ func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> t
// CHECK: [[VAL_394:%.*]] = "tf.MatMul"([[VAL_392]], [[VAL_393]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
// CHECK: return [[VAL_394]] : tensor<1x1xf32>
// CHECK: }
// CHECK-LABEL: func @broadcast_in_dim_tf_style(
// CHECK-SAME: [[VAL_395:%.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
// CHECK: [[VAL_396:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64>
// CHECK: [[VAL_397:%.*]] = "tf.BroadcastTo"([[VAL_395]], [[VAL_396]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK: return [[VAL_397]] : tensor<3x8x8x16xf32>
// CHECK: }
// CHECK-LABEL: func @broadcast_in_dim_general_case(
// CHECK-SAME: [[VAL_398:%.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
// CHECK: [[VAL_399:%.*]] = constant dense<[3, 1, 1, 16]> : tensor<4xi64>
// CHECK: [[VAL_400:%.*]] = "tf.Reshape"([[VAL_398]], [[VAL_399]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32>
// CHECK: [[VAL_401:%.*]] = constant dense<[3, 8, 8, 16]> : tensor<4xi64>
// CHECK: [[VAL_402:%.*]] = "tf.BroadcastTo"([[VAL_400]], [[VAL_401]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK: return [[VAL_402]] : tensor<3x8x8x16xf32>
// CHECK: }

View File

@ -1,83 +1,61 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 175 : i32}} {
func @main(%arg0: tensor<32x!tf.string>) -> (tensor<?x2xi64>) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} {
// CHECK: name: "tf.ParseExample"
// CHECK-NEXT: op: "ParseExample"
// CHECK-NEXT: input: "tf.Const3"
// CHECK-NEXT: input: "tf.Const"
// CHECK-NEXT: input: "tf.Const1"
// CHECK-NEXT: input: "tf.Const2"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "Ndense"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 1
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "Nsparse"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "Tdense"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_INT64
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK: key: "dense_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK-NEXT: dim {
// CHECK-NEXT: size: 1
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "sparse_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 413 : i32}} {
func @main() -> tensor<*xi64> attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "result"}} {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
%outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
%outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
%outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
%outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%arg0, %outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = [#tf.shape<>, #tf.shape<>], device = "", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) loc("ParseExample")
// CHECK: name: "ParseExample"
// CHECK-NEXT: op: "ParseExampleV2"
// CHECK-NEXT: input: "input0"
// CHECK-NEXT: input: "tf.Const3"
// CHECK-NEXT: input: "tf.Const5"
// CHECK-NEXT: input: "tf.Const2"
// CHECK-NEXT: input: "tf.Const4"
// CHECK-NEXT: input: "tf.Const"
// CHECK-NEXT: input: "tf.Const1"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "Tdense"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK: key: "dense_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK-NEXT: }
// CHECK-NEXT: shape {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "num_sparse"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 2
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "ragged_split_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "ragged_value_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "sparse_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_STRING
// CHECK-NEXT: type: DT_INT64
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
tf_executor.fetch %outputs_10#0 : tensor<?x2xi64>
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<"value"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
%outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<"value"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<-1> : tensor<i64>} : () -> tensor<i64>
%outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<""> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
%outputs_6, %control_7 = tf_executor.island wraps "tf.ParseExample"(%outputs_4, %outputs, %outputs_0, %outputs_2) {dense_shapes = [#tf.shape<1>], device = "", operand_segment_sizes = dense<[1, 1, 0, 1, 1]> : vector<5xi32>, result_segment_sizes = dense<[0, 0, 0, 1]> : vector<4xi32>} : (tensor<1x!tf.string>, tensor<1x!tf.string>, tensor<!tf.string>, tensor<i64>) -> tensor<*xi64>
tf_executor.fetch %outputs_6 : tensor<*xi64>
}
return %0#0 : tensor<?x2xi64>
// CHECK: name: "ParseExample/ParseExampleV2"
// CHECK-NEXT: op: "_Retval"
// CHECK-NEXT: input: "ParseExample"
return %0 : tensor<*xi64>
}
}

View File

@ -0,0 +1,83 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 175 : i32}} {
func @main(%arg0: tensor<32x!tf.string>) -> (tensor<?x2xi64>) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} {
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
%outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
%outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
%outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
%outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%arg0, %outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = [#tf.shape<>, #tf.shape<>], device = "", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) loc("ParseExample")
// CHECK: name: "ParseExample"
// CHECK-NEXT: op: "ParseExampleV2"
// CHECK-NEXT: input: "input0"
// CHECK-NEXT: input: "tf.Const3"
// CHECK-NEXT: input: "tf.Const5"
// CHECK-NEXT: input: "tf.Const2"
// CHECK-NEXT: input: "tf.Const4"
// CHECK-NEXT: input: "tf.Const"
// CHECK-NEXT: input: "tf.Const1"
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "Tdense"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: type: DT_FLOAT
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK: key: "dense_shapes"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: shape {
// CHECK-NEXT: }
// CHECK-NEXT: shape {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "num_sparse"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 2
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "ragged_split_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "ragged_value_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: attr {
// CHECK-NEXT: key: "sparse_types"
// CHECK-NEXT: value {
// CHECK-NEXT: list {
// CHECK-NEXT: type: DT_STRING
// CHECK-NEXT: type: DT_INT64
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
tf_executor.fetch %outputs_10#0 : tensor<?x2xi64>
}
return %0#0 : tensor<?x2xi64>
// CHECK: name: "ParseExample/ParseExampleV2"
// CHECK-NEXT: op: "_Retval"
// CHECK-NEXT: input: "ParseExample"
}
}

View File

@ -0,0 +1,109 @@
// RUN: tf-opt %s -op-fusion | FileCheck %s --dump-input-on-failure
//===----------------------------------------------------------------------===//
// Conv2D + BiasAdd + <Activation> fusions.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: conv2DBiasAdd_noActivation
func @conv2DBiasAdd_noActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_1]]
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Identity"(%1) : (tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
}
// CHECK-LABEL: conv2DBiasAdd_reluActivation
func @conv2DBiasAdd_reluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_1]]
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Relu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: conv2DBiasAdd_relu6Activation
func @conv2DBiasAdd_relu6Activation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Relu6"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_1]]
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: conv2DBiasAdd_eluActivation
func @conv2DBiasAdd_eluActivation(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_0:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd", "Elu"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_1]]
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: conv2DBiasAdd_convMultipleUses
func @conv2DBiasAdd_convMultipleUses(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
// CHECK-NOT: "tf._FusedConv2D"
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
%4 = "tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
return %3, %4 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: conv2DBiasAdd_biasAddMultipleUse
func @conv2DBiasAdd_biasAddMultipleUse(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
// CHECK-DAG: %[[VAL:.*]] = "tf._FusedConv2D"(%arg2, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], epsilon = 0.000000e+00 : f32, explicit_paddings = [], fused_ops = ["BiasAdd"], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>, tensor<128xf32>) -> tensor<*xf32>
// CHECK-DAG: %[[VAL_0:.*]] = "tf.Elu"(%[[VAL]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-DAG: %[[VAL_1:.*]] = "tf.Identity"(%[[VAL_0]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-DAG: %[[VAL_2:.*]] = "tf.Identity"(%[[VAL]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_1]], %[[VAL_2]]
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
%4 = "tf.Identity"(%1) : (tensor<*xf32>) -> tensor<*xf32>
return %3, %4 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: conv2D_noFusion
func @conv2D_noFusion(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK-NOT: "tf._FusedConv2D"
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: conv2D_noFusion1
func @conv2D_noFusion1(%arg0: tensor<*xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK-NOT: "tf._FusedConv2D"
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
// The result of the conv must be the first input to BiasAdd to be fusable.
%1 = "tf.BiasAdd"(%arg0, %0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: conv2D_dataFormatMismatch
func @conv2D_dataFormatMismatch(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128xf32>, %arg2: tensor<8x32x32x3xf32>) -> (tensor<*xf32>) {
// CHECK-NOT: "tf._FusedConv2D"
%0 = "tf.Conv2D"(%arg2, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<8x32x32x3xf32>, tensor<1x1x3x128xf32>) -> tensor<*xf32>
// The result of the conv must be the first input to BiasAdd to be fusable.
%1 = "tf.BiasAdd"(%0, %arg0) {data_format = "NCHW"} : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32>
%2 = "tf.Elu"(%1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}

View File

@ -0,0 +1,42 @@
// RUN: tf-opt -tf-rewrite-tpu-embedding-ops %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @recv_tpu_embedding_activations
func @recv_tpu_embedding_activations() -> (tensor<512x256xf32>) {
// CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf.variant>
// CHECK: %[[RESULT:.*]] = "tf._RecvTPUEmbeddingActivations"(%[[DATA]]) {config = {{.*}}} : (tensor<!tf.variant>) -> tensor<512x256xf32>
// CHECK: return %[[RESULT]]
// CHECK-NOT: tf.RecvTPUEmbeddingActivations
// CHECK-NOT: tf.SendTPUEmbeddingGradients
%0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
return %0 : tensor<512x256xf32>
}
// CHECK-LABEL: func @send_tpu_embedding_gradients
func @send_tpu_embedding_gradients(%arg0: tensor<512x256xf32>) -> () {
// CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"() {config = {{.*}}} : () -> tensor<!tf.variant>
// CHECK: "tf._SendTPUEmbeddingGradients"(%arg0, %[[DATA]]) {config = {{.*}}, operand_segment_sizes = dense<[1, 0, 1]> : vector<3xi32>} : (tensor<512x256xf32>, tensor<!tf.variant>) -> ()
// CHECK-NOT: tf.SendTPUEmbeddingGradients
// CHECK-NOT: tf.RecvTPUEmbeddingActivations
"tf.SendTPUEmbeddingGradients"(%arg0) {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
return
}
// CHECK-LABEL: func @recv_send_ops
func @recv_send_ops() -> () {
// CHECK: %[[DATA:.*]] = "tf._RecvTPUEmbeddingDeduplicationData"()
// CHECK: %[[ACTIVATIONS:.*]] = "tf._RecvTPUEmbeddingActivations"(%[[DATA]])
// CHECK: "tf._SendTPUEmbeddingGradients"(%[[ACTIVATIONS]], %[[DATA]])
%0 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02"} : () -> tensor<512x256xf32>
"tf.SendTPUEmbeddingGradients"(%0) {config = "\0A%\0A\0Dwatches_table\10\F5\03\18\80\02 \01*\0C\1A\00j\05\0D\00\00\80?\88\01\01\10\02\18\80\04 \01(\02", operand_segment_sizes = dense<[1, 0]> : vector<2xi32>} : (tensor<512x256xf32>) -> ()
return
}
// CHECK-LABEL: func @no_embedding_ops
func @no_embedding_ops(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>) {
// CHECK: tf.Add
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

View File

@ -280,7 +280,7 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK: tf_executor.Exit
// CHECK-SAME: : tensor<?x?x?xf32>
// CHECK: tf_executor.LoopCond
// CHECK-SAME: : tensor<*xi1>
// CHECK-SAME: tensor<i1>
%merge:3 = "tf_executor.Merge"(%island#0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<i32>, !tf_executor.control)
%switch:3 = "tf_executor.Switch"(%island#0, %arg2) : (tensor<?x?x?xf32>, tensor<i1>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
%switchn:3 = "tf_executor.SwitchN"(%island#0, %arg3) {num_outs = 2} : (tensor<?x?x?xf32>, tensor<i32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, !tf_executor.control)
@ -416,4 +416,21 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
%2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return
}
// CHECK-LABEL: cast_at_end(%arg0:
// CHECK-SAME: tensor<16x194x199x4xui8>, tensor<16x194x199x4xi8>, tensor<*xi8>
func @cast_at_end(%arg0: tensor<16x194x199x4xf32>, %arg1: tensor<16x194x199x4xi8>) -> (tensor<*xui8>, tensor<*xi8>, tensor<*xi8>) {
// CHECK: %[[CAST_RESULT_0:.*]] = "tf.Cast"(%arg0)
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xui8>
%27 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<16x194x199x4xf32>) -> tensor<*xui8>
// CHECK: %[[CAST_RESULT_1:.*]] = "tf.Cast"(%arg0)
// CHECK-SAME: (tensor<16x194x199x4xf32>) -> tensor<16x194x199x4xi8>
// CHECK: %[[CAST_RESULT_2:.*]] = "tf.Cast"(%[[CAST_RESULT_1]])
// CHECK-SAME: (tensor<16x194x199x4xi8>) -> tensor<*xi8>
%28 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<16x194x199x4xf32>) -> tensor<*xi8>
// CHECK: %[[ADDI:.*]] = addi %[[CAST_RESULT_2]], %[[CAST_RESULT_2]]
%2 = addi %28, %28 : tensor<*xi8>
// CHECK: return %[[CAST_RESULT_0]], %[[CAST_RESULT_1]], %[[ADDI]]
return %27, %28, %2 : tensor<*xui8>, tensor<*xi8>, tensor<*xi8>
}
}

View File

@ -0,0 +1,11 @@
// RUN: tf-opt %s -tf-standard-pipeline | FileCheck %s
// CHECK-LABEL: removeDeadReadVariableOp
func @removeDeadReadVariableOp(%arg0: tensor<!tf.resource<tensor<f32>>>) -> tensor<f32> {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
return %0: tensor<f32>
// CHECK: %0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: return %0
}

View File

@ -0,0 +1,250 @@
// RUN: tf-opt %s -tf-tpu-outside-compilation-cluster | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @one_cluster_no_dependencies
func @one_cluster_no_dependencies() {
// CHECK: "tf.opA"
// CHECK: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}"
// CHECK: "tf.opC"
"tf_device.cluster"() ( {
"tf.opA"() : () -> ()
"tf.opB"() {_xla_outside_compilation = "0"} : () -> ()
"tf.opC"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @one_cluster_with_one_op
func @one_cluster_with_one_op() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}"
// CHECK-NEXT: "tf.opC"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
"tf.opC"(%b) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @one_cluster_with_two_ops
func @one_cluster_with_two_ops() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER2:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER2]]"
// CHECK-NEXT: "tf.opD"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
"tf.opD"(%c) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @one_cluster_with_three_ops
func @one_cluster_with_three_ops() {
// CHECK: "tf.opA"
// CHECK: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER3:[a-zA-Z_0-9]+]]"
// CHECK: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER3]]"
// CHECK: "tf.opD"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER3]]"
// CHECK: "tf.opE"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"(%b, %c) {_xla_outside_compilation = "0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.opE"(%d) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_no_dependencies
func @two_clusters_no_dependencies() {
// CHECK: "tf.opA"
// CHECK: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4:[a-zA-Z_0-9]+]]"
// CHECK: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER4]]"
// CHECK: "tf.opD"
"tf_device.cluster"() ( {
"tf.opA"() : () -> ()
"tf.opB"() {_xla_outside_compilation = "0"} : () -> ()
"tf.opC"() {_xla_outside_compilation = "0"} : () -> ()
"tf.opD"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_with_one_op_each
func @two_clusters_with_one_op_each() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER6:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opC"
// CHECK-NEXT: "tf.opD"
// CHECK-NOT: _xla_outside_compilation = "[[CLUSTER6]]"
// CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}"
// CHECK-NEXT: "tf.opE"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"(%c) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
"tf.opE"(%d) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_with_two_ops_each
func @two_clusters_with_two_ops_each() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER8:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER8]]"
// CHECK-NEXT: "tf.opD"
// CHECK-NEXT: "tf.opE"
// CHECK-NOT: _xla_outside_compilation = "[[CLUSTER8]]"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opF"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9]]"
// CHECK-NEXT: "tf.opG"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"(%c) : (tensor<i32>) -> tensor<i32>
%e = "tf.opE"(%d) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%f = "tf.opF"(%e) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
"tf.opG"(%f) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_with_same_parent
func @two_clusters_with_same_parent() {
// CHECK: "tf.opA"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER10:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: "tf.opC"
// CHECK-NOT: _xla_outside_compilation = "[[CLUSTER10]]"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opD"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER10]]"
// CHECK-NEXT: "tf.opE"
// CHECK-NEXT: "tf.opF"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11]]"
// CHECK-NEXT: "tf.opG"
"tf_device.cluster"() ( {
%a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%b = "tf.opB"(%a) : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%e = "tf.opE"(%d) : (tensor<i32>) -> tensor<i32>
%f = "tf.opF"(%e) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%g = "tf.opG"(%c, %f) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_with_same_outside_compiled_parent
func @two_clusters_with_same_outside_compiled_parent() {
// CHECK: "tf.opA"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: "tf.opC"
// CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opD"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12]]"
// CHECK-NEXT: "tf.opE"
// CHECK-NEXT: "tf.opF"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13]]"
// CHECK-NEXT: "tf.opG"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13]]"
"tf_device.cluster"() ( {
%a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%b = "tf.opB"(%a) : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%e = "tf.opE"(%d) : (tensor<i32>) -> tensor<i32>
%f = "tf.opF"(%e) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%g = "tf.opG"(%c, %f) {_xla_outside_compilation = "0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @parent_with_a_non_outside_compiled_child
func @parent_with_a_non_outside_compiled_child() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER14:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER14]]"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%c = "tf.opC"(%a, %b) {_xla_outside_compilation = "0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @outside_compile_with_block
func @outside_compile_with_block() {
// CHECK: "tf.opA"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]"
// CHECK: "tf.opC"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]"
"tf_device.cluster"() ( {
%a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
%b = "tf.opB"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
"tf_device.cluster" () ( {
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
%c = "tf.opC"() {_xla_outside_compilation = "0"} : () -> tensor<i32>
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// CHECK-LABEL: func @two_clusters_with_one_op_each_with_indirect_dependency
func @two_clusters_with_one_op_each_with_indirect_dependency() {
// CHECK: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-SAME: _xla_outside_compilation = "[[CLUSTER16:[a-zA-Z_0-9]+]]"
// CHECK-NEXT: "tf.opC"
// CHECK-NEXT: "tf.opD"
// CHECK-NEXT: "tf.opE"
// CHECK-NOT: _xla_outside_compilation = "[[CLUSTER16]]"
// CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}"
// CHECK-NEXT: "tf.opF"
"tf_device.cluster"() ( {
%a = "tf.opA"() : () -> tensor<i32>
%b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
%c = "tf.opC"(%b) : (tensor<i32>) -> tensor<i32>
%d = "tf.opD"(%c) : (tensor<i32>) -> tensor<i32>
%e = "tf.opE"(%d) {_xla_outside_compilation = "0"} : (tensor<i32>) -> tensor<i32>
"tf.opF"(%e) : (tensor<i32>) -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}

View File

@ -33,7 +33,7 @@ namespace TFDevice {
namespace {
constexpr char kReplicationAttr[] = "tf_device.is_same_data_across_replicas";
constexpr char kReplicationAttr[] = "xla_hlo.is_same_data_across_replicas";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
// Analyzes the inputs to ClusterFuncOps in the module, and annotates their
@ -83,8 +83,7 @@ void AnnotateParameterReplication::runOnOperation() {
// Not a replication-invariant operand.
continue;
}
func.setArgAttr(entry.index(), kReplicationAttr,
builder.getBoolAttr(true));
func.setArgAttr(entry.index(), kReplicationAttr, builder.getUnitAttr());
}
});
}

View File

@ -65,6 +65,12 @@ def BatchMatMulV2ToMatMul : Pat<(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y),
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
//===----------------------------------------------------------------------===//
// BiasAddV1 op patterns.
//===----------------------------------------------------------------------===//
def BiasAddV1ToBiasAdd : Pat<(TF_BiasAddV1Op $arg0, $arg1),
(TF_BiasAddOp $arg0, $arg1, ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">)>;
//===----------------------------------------------------------------------===//
// Bitcast op patterns.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -159,6 +160,44 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
return reshape.getResult();
}
// Returns true if broadcast_dimensions obey Tensorflow convention, as in new
// dimensions are added as prefix.
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
Value output) {
// broadcast_dimensions is an increasing list by definition, thus it suffices
// to check the first element.
int64_t input_rank = broadcast_dimensions.getNumElements();
int64_t output_rank = output.getType().cast<ShapedType>().getRank();
return input_rank == 0 ||
(broadcast_dimensions.getValue({0}).cast<IntegerAttr>().getInt() ==
output_rank - input_rank);
}
// Returns the intermediate shape that input tensor should be reshaped to during
// legalization of BroadcastInDimOp.
ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
DenseIntElementsAttr broadcast_dimensions,
Value output) {
// Initialize expanded shape with output rank and dimensions of 1.
SmallVector<Attribute, 4> expanded_shape(
output.getType().cast<ShapedType>().getRank(),
/*Value=*/rewriter.getI64IntegerAttr(1));
// Set dimension sizes specified by broadcast_dimensions.
ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
for (auto x : llvm::enumerate(broadcast_dimensions)) {
expanded_shape[x.value().getSExtValue()] =
rewriter.getI64IntegerAttr(input_shape[x.index()]);
}
// Create the expanded type wrapped in a ConstantOp.
auto attr_type =
RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
rewriter.getIntegerType(64));
auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
return rewriter.create<ConstantOp>(output.getLoc(), attr_type, attr);
}
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
/// Performs the lowering to XLA dialect.

View File

@ -29,6 +29,17 @@ def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
// Return a constant op that carries the shape of the given value.
def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">;
// Check if broadcast dimensions match Tensorflow convention.
def IsTFStyleBroadcast : Constraint<CPred<"IsTFStyleBroadcast($0, $1)">,
"new dimensions are added as prefix">;
// Check if broadcast dimensions do not match Tensorflow convention.
def IsNotTFStyleBroadcast : Constraint<Neg<CPred<"IsTFStyleBroadcast($0, $1)">>,
"new dimensions are inserted in intermediate positions">;
// Return intermediate shape before broadcasting, wrapped in a constant op.
def ExpandedShape : NativeCodeCall<"ExpandedShape($_builder, $0, $1, $2)">;
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
//===----------------------------------------------------------------------===//
@ -38,6 +49,7 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
// TODO(b/158025719): Properly handle broadcast_dimensions.
foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op],
[HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp],
[HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp],
@ -80,6 +92,7 @@ def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
@ -112,6 +125,16 @@ def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>;
def : Pat<(HLO_BroadcastOp $arg, $shape),
(TF_BroadcastToOp $arg, (TF_ConstOp $shape))>;
def : Pat<(HLO_BroadcastInDimOp:$output $input, $broadcast_dimensions),
(TF_BroadcastToOp $input, (ShapeToConst $output)),
[(IsTFStyleBroadcast $broadcast_dimensions, $output)]>;
def : Pat<(HLO_BroadcastInDimOp:$output $input, $broadcast_dimensions),
(TF_BroadcastToOp
(TF_ReshapeOp
$input,
(ExpandedShape $input, $broadcast_dimensions, $output)),
(ShapeToConst $output)),
[(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>;
def : Pat<(HLO_TransposeOp $arg, $permutation),
(TF_TransposeOp $arg, (TF_ConstOp $permutation))>;
def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>;

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <iostream>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and based on different
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers the general CPU case and at the moment does not account for any
// specific target configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct OpFusionPass : public PassWrapper<OpFusionPass, FunctionPass> {
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}
// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}
// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
//
// Note that fusion with activation is preferred, but a Conv2D and BiasAdd can
// also be replaced by a _FusedConv2D if there is no other activation function.
// i.e., this class also supports the following fusion:
// Conv2D + BiasAdd -> _FusedConv2D
//
// TODO(b/158266331): Support fusing Conv2D + BiasAdd + a chain of activations.
class FuseConv2DBiasAdd : public OpRewritePattern<Conv2DOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Conv2DOp op,
PatternRewriter &rewriter) const override {
// If the convolution is used in multiple places, fusing it will only create
// more convolutions, which is slower.
if (!op.getResult().hasOneUse())
return rewriter.notifyMatchFailure(op, "result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(op);
if (!bias_add) {
return rewriter.notifyMatchFailure(
op, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
// Check that Conv and BiasAdd formats match.
if (op.data_format() != bias_add.data_format()) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << op.data_format() << ")";
});
}
SmallVector<Location, 3> locations{op.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
GetOpNameWithoutDialect(bias_add), rewriter.getContext())};
Type result_type;
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(StringAttr::get(GetOpNameWithoutDialect(activation),
rewriter.getContext()));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto loc = rewriter.getFusedLoc(locations);
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, rewriter.getContext());
// Epsilon is used only in fusions with the BatchNorm op.
APFloat epsilon = APFloat(0.0f);
auto fused_op = rewriter.create<_FusedConv2DOp>(
loc, result_type, op.input(), op.filter(), bias_add.bias(),
op.strides(), op.padding(), op.explicit_paddings(), op.data_format(),
op.dilations(), op.use_cudnn_on_gpu(), fused_ops_attr, epsilon);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, {fused_op});
return success();
}
};
void OpFusionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd>(&getContext());
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
return std::make_unique<OpFusionPass>();
}
static PassRegistration<OpFusionPass> pass(
"op-fusion",
"Replaces commonly occurring subgraphs with optimized fused kernels");
} // namespace TF
} // namespace mlir

View File

@ -52,6 +52,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateBatchMatMulToEinsumPass();
// Optimizes Tensorflow graph.
std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();
// Creates pass to rewrite RecvTPUEmbeddingActivationsOp and
// SendTPUEmbeddingGradients ops to internal variants.
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps();
// Performs specific fusion for GPU targets.
std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
@ -140,6 +144,9 @@ CreateTensorArrayOpsDecompositionPass();
// Create a pass that legalize HLO to TF dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// Creates a pass that performs fusion of common sequences of ops.
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
} // namespace TF
namespace TFControlFlow {
@ -265,6 +272,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
// run-time according to compilation result.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
// Creates a pass that groups outside compiled operations (CPU ops inside TPU
// cluster) into clusters that can be extracted and run on the CPU.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUOutsideCompilationClusterPass();
// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// at head/tail of TPU cluster to run before/after TPU computation.
std::unique_ptr<OperationPass<ModuleOp>>

View File

@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Rewrites RecvTPUEmbeddingActivationsOp and SendTPUEmbeddingGradients ops to
// internal variants by introducing _RecvTPUEmbeddingDeduplicationData op.
struct RewriteTPUEmbeddingOps
: public PassWrapper<RewriteTPUEmbeddingOps, FunctionPass> {
void runOnFunction() override;
};
// Rewrites the given op to `OpT` op after adding the given operand at the end.
template <typename OpT>
OpT AddOperandAndRewriteAs(Operation* op, Value operand, OpBuilder* builder) {
builder->setInsertionPoint(op);
auto operands = llvm::to_vector<4>(op->getOperands());
operands.push_back(operand);
auto new_op = builder->create<OpT>(op->getLoc(), op->getResultTypes(),
operands, op->getAttrs());
op->replaceAllUsesWith(new_op.getOperation()->getResults());
op->erase();
return new_op;
}
// Returns success if the function has at most one op of the template type and
// assigns it to `result`, if present. If there are multiple such ops, returns
// failure.
template <typename OpT>
LogicalResult GetOp(FuncOp func, OpT* result) {
*result = {};
for (auto op : func.getOps<OpT>()) {
if (*result) return op.emitError("should be unique within a function");
*result = op;
}
return success();
}
void RewriteTPUEmbeddingOps::runOnFunction() {
FuncOp func = getFunction();
RecvTPUEmbeddingActivationsOp recv_op;
if (failed(GetOp(func, &recv_op))) return signalPassFailure();
SendTPUEmbeddingGradientsOp send_op;
if (failed(GetOp(func, &send_op))) return signalPassFailure();
// No TPU embedding ops.
if (!recv_op && !send_op) return;
Location loc = recv_op ? recv_op.getLoc() : send_op.getLoc();
StringRef config = recv_op ? recv_op.config() : send_op.config();
// Create _RecvTPUEmbeddingDeduplicationData op.
OpBuilder builder(func.getBody());
auto output_ty = RankedTensorType::get({}, VariantType::get(&getContext()));
auto dedup_op = builder.create<_RecvTPUEmbeddingDeduplicationDataOp>(
loc, output_ty, config);
// Rewrite RecvTPUEmbeddingActivations op to the corresponding internal op.
if (recv_op)
AddOperandAndRewriteAs<_RecvTPUEmbeddingActivationsOp>(recv_op, dedup_op,
&builder);
// Rewrite SendTPUEmbeddingGradients op to the corresponding internal op and
// then update the OperandSegmentSize attribute.
if (send_op) {
int32_t operand_sizes[] = {static_cast<int32_t>(send_op.N()),
static_cast<int32_t>(send_op.NN()), 1};
auto attr_ty = VectorType::get(3, builder.getI32Type());
auto operand_size_attr = DenseIntElementsAttr::get(attr_ty, operand_sizes);
auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
send_op, dedup_op, &builder);
new_send_op.setAttr(new_send_op.getOperandSegmentSizeAttr(),
operand_size_attr);
}
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps() {
return std::make_unique<RewriteTPUEmbeddingOps>();
}
static PassRegistration<RewriteTPUEmbeddingOps> pass(
"tf-rewrite-tpu-embedding-ops",
"Rewrites TPU embedding send/recv ops by adding TPU embedding "
"deduplication data");
} // namespace TF
} // namespace mlir

View File

@ -114,12 +114,22 @@ Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
// Returns if the shape inference pass supports an op outside the TF dialect.
bool IsSupportedNonTFOp(Operation* op) {
return isa<tf_executor::YieldOp>(op) || isa<tf_executor::IslandOp>(op) ||
return isa<ReturnOp>(op) || isa<tf_device::ReturnOp>(op) ||
isa<tf_executor::EnterOp>(op) || isa<tf_executor::ExitOp>(op) ||
isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) ||
isa<tf_executor::NextIterationSinkOp>(op) || isa<ReturnOp>(op) ||
isa<tf_device::ReturnOp>(op) || isa<tf_executor::MergeOp>(op) ||
isa<tf_executor::SwitchOp>(op) || isa<tf_executor::SwitchNOp>(op) ||
isa<tf_executor::EnterOp>(op) || isa<tf_executor::ExitOp>(op);
isa<tf_executor::IslandOp>(op) || isa<tf_executor::LoopCondOp>(op) ||
isa<tf_executor::MergeOp>(op) ||
isa<tf_executor::NextIterationSinkOp>(op) ||
isa<tf_executor::SwitchNOp>(op) || isa<tf_executor::SwitchOp>(op) ||
isa<tf_executor::YieldOp>(op);
}
// Returns whether a cast back would need to be inserted, e.g., whether the
// operation of which use is an operand allows for shape refinement without
// a cast.
bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
return use.getOwner()->getDialect() != tf_dialect &&
!IsSupportedNonTFOp(use.getOwner());
}
// Inserts tf.Cast operation when changing the type of a result if the user is
@ -139,9 +149,7 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
return Value(cast_op);
};
for (OpOperand& use : make_early_inc_range(result.getUses())) {
if (use.getOwner()->getDialect() != tf_dialect &&
!IsSupportedNonTFOp(use.getOwner()))
use.set(get_cast_op());
if (NeedsCastBack(use, tf_dialect)) use.set(get_cast_op());
}
}
@ -316,6 +324,34 @@ bool InferShapeForCall(Operation* op) {
return changed;
}
bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
Value result = op.getResult();
if (!CanBeRefined(result.getType())) return false;
Type operand_type = op.getOperand().getType();
auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
if (!ranked_op_type) return false;
auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
if (ranked_res_type &&
ranked_op_type.getShape() == ranked_res_type.getShape())
return false;
// Avoid inserting a cast where no users types could be refined (e.g., where
// there would need to be a cast inserted for every user again).
if (llvm::all_of(result.getUses(), [tf_dialect](OpOperand& use) {
return NeedsCastBack(use, tf_dialect);
}))
return false;
auto new_type = RankedTensorType::get(
ranked_op_type.getShape(),
result.getType().cast<ShapedType>().getElementType());
auto old_type = result.getType();
result.setType(new_type);
AddCastBackForUnsupportedNonTFUses(op, op.getResult(), tf_dialect, old_type);
return true;
}
bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
Dialect* tf_dialect) {
Operation* op = infer_ti.getOperation();
@ -670,18 +706,11 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
return InferShapeForCall(op);
// tf.Cast are only inferred if they have at least one user in the tf dialect.
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
// the end of this function.
if (isa<CastOp>(op) &&
all_of(op->getResult(0).getUsers(), [&](Operation* user) {
return user->getDialect() != tf_dialect_;
})) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
"dialect operation users '"
<< *op << "'.\n");
return false;
}
// tf.Cast are only inferred if they have at least one user in the TF dialect
// or feeding into the function return. This is necessary to avoid inserting
// casts which cannot be refined.
if (auto cast_op = dyn_cast<CastOp>(op))
return InferShapeForCast(cast_op, tf_dialect_);
StringRef op_name = op->getName().getStringRef();
// Drop the `tf.` prefix to query TF registry.
@ -857,7 +886,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
auto new_type = get_tensor_type(shape_handle, new_element_type);
if (result.getType() == new_type) continue;
// Inserts a cast back to the original type if any user is not in the TF
// dialect.
// dialect or a return.
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
result.getType());
// Finally we inferred the shape and replace the type for this result.

View File

@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
struct TPUOutsideCompilationCluster
: public PassWrapper<TPUOutsideCompilationCluster, FunctionPass> {
void runOnFunction() override;
};
// Represents an outside compiled cluster. All ops that are added to the same
// cluster will be extracted together in a later pass.
class OutsideCompiledCluster {
public:
explicit OutsideCompiledCluster(int number)
: cluster_name_(llvm::formatv("cluster{0}", number).str()) {}
// Attempts to add an op to this cluster.
// This function requires all ops to be added before their uses.
bool AddOp(Operation* op) {
// Check if the op is safe to add before adding it.
bool add = IsSafeToAdd(op);
if (add) {
// Set the ops kXlaOutsideCompilationAttr to the cluster name.
op->setAttr(kXlaOutsideCompilationAttr,
StringAttr::get(cluster_name_, op->getContext()));
// Since we are adding the op to the cluster, the op is no longer
// considered a user of this cluster.
users_.erase(op);
}
// Add this op's users to the cluster users.
users_.insert(op->user_begin(), op->user_end());
return add;
}
private:
// Checks if it is safe for an op to be merged into this cluster.
bool IsSafeToAdd(Operation* op) {
// If the op is not marked for outside compilation it doesn't belong in a
// cluster.
if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
return false;
// Checks to see if the op's operands are related to this
// clusters users. If they are related, then there is an op between this
// op and the cluster. Since ops are added before their uses, there
// is no way for the op in-between to ever be added to this cluster
// therefore there is no way this op can ever be added to the cluster.
for (const Value& value : op->getOperands()) {
Operation* op_operand = value.getDefiningOp();
if (op_operand && users_.find(op_operand) != users_.end()) return false;
}
return true;
}
// users_ stores the direct and indirect users of the outside compiled ops in
// this cluster. It does NOT store the outside compiled ops that are a part
// of this cluster that will be collectively extracted and run on the cpu.
// users_ is consulted when attempting to add a new outside compiled to the
// cluster. If the new op's operand(s) are already in users_, it means that
// the operand(s) were not added to the cluster so it is not safe to add the
// new op to the cluster either.
llvm::SmallPtrSet<Operation*, 8> users_;
std::string cluster_name_;
};
void TPUOutsideCompilationCluster::runOnFunction() {
llvm::SmallVector<OutsideCompiledCluster, 8> clusters;
int cluster_counter = 0;
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
for (Operation& op : tpu_cluster.GetBody()) {
// Try to add the op to existing clusters.
bool added = false;
for (auto& cluster : clusters)
if ((added = cluster.AddOp(&op))) break;
// If the op cannot be added to existing clusters, create a new cluster.
if (!added) {
OutsideCompiledCluster new_cluster(cluster_counter++);
new_cluster.AddOp(&op);
clusters.push_back(new_cluster);
}
}
});
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUOutsideCompilationClusterPass() {
return std::make_unique<TPUOutsideCompilationCluster>();
}
static PassRegistration<TPUOutsideCompilationCluster> pass(
"tf-tpu-outside-compilation-cluster",
"Identifies clusters of operations assigned to outside compilation");
} // namespace TFTPU
} // namespace mlir

View File

@ -219,7 +219,7 @@ void BreakUpIslands::BreakUpIsland(
}
// Skip islands that are already only a single op.
if (hasSingleElement(island_body)) return;
if (island_op.WrapsSingleOp()) return;
auto control_type = tf_executor::ControlType::get(&getContext());
auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());

View File

@ -3572,6 +3572,24 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
/*func_name=*/"main");
}
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context) {
const tensorflow::FunctionDef* fdef = flib_def.Find(name.str());
if (fdef == nullptr)
return tensorflow::errors::NotFound("Cannot find function ", name.str());
std::unique_ptr<tensorflow::FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
&flib_def, &fbody));
tensorflow::GraphDebugInfo dummy_debug_info;
tensorflow::GraphImportConfig specs;
specs.graph_as_function = true;
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs, name);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -45,6 +46,13 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context);
// [Experimental]
// Given a Function, returns a MLIR module containing the graph, expressed with
// tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context);
// Given a SavedModel, returns a MLIR module containing the functions, expressed
// with tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(

View File

@ -184,7 +184,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
// only be lowered when tf.Shape is folded into a constant.
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32>
return %1 : tensor<10x19xf32>

View File

@ -379,13 +379,13 @@ Status ConvertAttributes(
func_call_attrs[string(name)] = value;
continue;
}
case mlir::StandardAttributes::Bool:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::BoolAttr>(), &value));
break;
case mlir::StandardAttributes::Integer:
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::IntegerAttr>(), &value));
if (auto boolAttr = attr.dyn_cast<mlir::BoolAttr>()) {
TF_RETURN_IF_ERROR(ConvertAttribute(boolAttr, &value));
} else {
TF_RETURN_IF_ERROR(
ConvertAttribute(attr.cast<mlir::IntegerAttr>(), &value));
}
break;
case mlir::StandardAttributes::Float:
TF_RETURN_IF_ERROR(

View File

@ -21,6 +21,7 @@ cc_library(
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TargetNVVMIR",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:lhlo",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -42,6 +43,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/NVVMIR.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
@ -216,14 +218,22 @@ Status PropagateStaticShapeKnowledgeToKernel(
}
return Status::OK();
}
void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
} // namespace
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
context.allowUnregisteredDialects(); // TODO(b/152572127)
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));

View File

@ -47,6 +47,7 @@ bool ParseStringList(std::string string_list, std::vector<uint32_t>* result) {
} // namespace
int main(int argc, char** argv) {
std::string input_file = "foo.mlir";
std::string output_file = "foo.bin";
int32_t architecture = 50;
std::vector<uint32_t> tile_sizes;
@ -75,6 +76,7 @@ int main(int argc, char** argv) {
};
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("input", &input_file, "input file"),
tensorflow::Flag("output", &output_file, "output file"),
tensorflow::Flag("arch", &architecture,
"target architecture (e.g. 50 for sm_50)"),
@ -94,8 +96,16 @@ int main(int argc, char** argv) {
std::pair<int32_t, int32_t> compute_capability(architecture / 10,
architecture % 10);
std::string tf_code;
auto read_status = tensorflow::ReadFileToString(tensorflow::Env::Default(),
input_file, &tf_code);
if (!read_status.ok()) {
LOG(ERROR) << read_status;
return 1;
}
auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
argv[1], compute_capability, tile_sizes, same_shape, unroll_factors);
tf_code, compute_capability, tile_sizes, same_shape, unroll_factors);
if (!cubin.ok()) {
LOG(ERROR) << cubin.status();

View File

@ -634,7 +634,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
I32Attr:$index
);
let results = (outs HLO_TensorOrTuple);
let results = (outs HLO_TensorOrTokenOrTuple);
let hasFolder = 1;

View File

@ -62,6 +62,8 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;
// Dynamic representation of a shape vector as a tensor.

View File

@ -73,7 +73,8 @@ constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas";
constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes";
constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas";
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
@ -381,21 +382,41 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
return output;
}
// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
llvm::StringRef sharding) {
xla::OpSharding sharding_proto;
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
return sharding_proto;
}
// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
mlir::Operation* op) {
auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
if (!sharding) {
return absl::nullopt;
}
::xla::OpSharding sharding_proto;
if (!::tensorflow::protobuf::TextFormat::ParseFromString(
sharding.getValue().str(), &sharding_proto)) {
return absl::nullopt;
}
return sharding_proto;
if (!sharding) return absl::nullopt;
return CreateOpShardingFromStringRef(sharding.getValue());
}
// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
mlir::Operation* op) {
xla::FrontendAttributes frontend_attributes;
auto frontend_attributes_dict =
op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);
if (!frontend_attributes_dict) return frontend_attributes;
for (const auto& attr : frontend_attributes_dict)
if (auto value_str_attr = attr.second.dyn_cast<mlir::StringAttr>())
frontend_attributes.mutable_map()->insert(
{attr.first.str(), value_str_attr.getValue().str()});
return frontend_attributes;
}
// Checks if all shardings are set.
@ -407,14 +428,6 @@ static bool AllOptionalShardingsAreSet(
});
}
// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
llvm::StringRef sharding) {
xla::OpSharding sharding_proto;
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
return sharding_proto;
}
// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
mlir::FuncOp function,
@ -1144,8 +1157,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
bool any_arg_replicated = false;
entry_args_same_across_replicas.reserve(f.getNumArguments());
for (int64_t i = 0; i < f.getNumArguments(); ++i) {
auto attr = f.getArgAttrOfType<mlir::BoolAttr>(i, kRepicationAttr);
entry_args_same_across_replicas.push_back(attr && attr.getValue());
auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kRepicationAttr);
entry_args_same_across_replicas.push_back(attr != nullptr);
any_arg_replicated |= entry_args_same_across_replicas.back();
// Pass the alias info to the builder so that it will build the alias info
// into the resulting HloModule.

View File

@ -138,6 +138,13 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
os << " xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
"CreateOpShardingFromAttribute(op));\n\n";
// Create a scoped object to assign frontend attributes to generated XLA ops.
// Any HLO can have an attribute of "frontend_attributes", which are used to
// pass hints / configuration options.
os << " xla::XlaScopedFrontendAttributesAssignment "
"frontend_attributes(lowering_context.builder, "
"CreateOpFrontendAttributesFromAttribute(op));\n\n";
// Retrieve all the definitions derived from HLO_Op and sort by record name.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
// Skip operations that have a custom exporter.

View File

@ -1,4 +1,4 @@
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@ -157,13 +157,16 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape)
{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
: (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = call @external_func()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
@ -175,7 +178,33 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK-NEXT: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %[[RESULT]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
return

View File

@ -845,7 +845,19 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
// CHECK: "xla_hlo.infeed"
// An additional sharding is added at the end to account for token result.
// CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A"
// Proto debug string:
// type: TUPLE
// tuple_shardings {
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// }
// tuple_shardings {
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
// }
// CHECK-SAME: xla_hlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
%0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
return %0 : tensor<8xi32>
}
@ -1526,6 +1538,67 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
return %0 : tensor<8xcomplex<f32>>
}
//===----------------------------------------------------------------------===//
// Shape op legalization.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<1xi32>
}
// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_with_const
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
//===----------------------------------------------------------------------===//
// Transpose op legalization.
//===----------------------------------------------------------------------===//
@ -2746,8 +2819,8 @@ func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
// CHECK-LABEL: func @tile_by_reshape
func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> {
// CHECK: %[[BROADCASTED:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<4x7x8x3xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[BROADCASTED]]) : (tensor<4x7x8x3xf32>) -> tensor<28x24xf32>
// CHECK: %[[BROADCASTED:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32>
// CHECK: return %[[RESULT]] : tensor<28x24xf32>
%multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
%0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32>
@ -2861,6 +2934,50 @@ func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
return %3 : tensor<5xf32>
}
// CHECK-LABEL: func @range_dynamic
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> {
// CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract %arg1, %arg0
// CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"([[SUB]])
// CHECK-DAG: [[CONVERT1:%.+]] = "xla_hlo.convert"([[ABS1]])
// CHECK-DAG: [[CONVERT2:%.+]] = "xla_hlo.convert"(%arg2)
// CHECK-DAG: [[DIV:%.+]] = xla_hlo.divide [[CONVERT1]], [[CONVERT2]]
// CHECK-DAG: [[CEIL:%.+]] = "xla_hlo.ceil"([[DIV]])
// CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"([[CEIL]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[CONVERT3]])
// CHECK-DAG: [[IOTA:%.+]] = "xla_hlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
// CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"(%arg0)
// CHECK-DAG: [[CONVERT4:%.+]] = "xla_hlo.convert"(%arg2)
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
// CHECK: return [[ADD]]
return %2 : tensor<?xf32>
}
// CHECK-LABEL: func @range_int_dynamic
// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32>
func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> {
// CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract %arg1, %arg0
// CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"([[SUB]])
// CHECK-DAG: [[CONVERT1:%.+]] = "xla_hlo.convert"([[ABS1]])
// CHECK-DAG: [[CONVERT2:%.+]] = "xla_hlo.convert"(%arg2)
// CHECK-DAG: [[DIV:%.+]] = xla_hlo.divide [[CONVERT1]], [[CONVERT2]]
// CHECK-DAG: [[CEIL:%.+]] = "xla_hlo.ceil"([[DIV]])
// CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"([[CEIL]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[CONVERT3]])
// CHECK-DAG: [[IOTA:%.+]] = "xla_hlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
// CHECK-DAG: [[CONVERT3:%.+]] = "xla_hlo.convert"(%arg0)
// CHECK-DAG: [[CONVERT4:%.+]] = "xla_hlo.convert"(%arg2)
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: return [[ADD]]
return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @linspace_static
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {

View File

@ -41,12 +41,12 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: [[C56:%.*]] = constant 56 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK: [[CFALSE:%.*]] = constant 0 : i1
// CHECK: [[CFALSE:%.*]] = constant false
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C112:%.*]] = constant 112 : index
// CHECK: [[CTRUE:%.*]] = constant 1 : i1
// CHECK: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>

View File

@ -228,32 +228,54 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim_with_expansion
func @broadcast_in_dim_with_expansion(%operand: memref<5x7x1xf32>,
%result: memref<7x10x6x4x5xf32>) {
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> ()
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[REASSOCIATION:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_in_dim_scalar
func @broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<7x10x6xf32>) {
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<7x10x6xf32>) -> ()
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
}
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps =
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
@ -262,19 +284,39 @@ func @broadcast_in_dim_scalar(%operand: memref<f32>,
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) {
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
// -----
// CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) {

View File

@ -150,7 +150,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant 1 : i1
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant true
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index

View File

@ -830,6 +830,13 @@ func @get_tuple_element(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// -----
func @get_tuple_element_token(%arg0: tuple<tensor<f32>, !xla_hlo.token>) -> !xla_hlo.token {
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tensor<f32>, !xla_hlo.token>) -> !xla_hlo.token
return %0 : !xla_hlo.token
}
// -----
func @get_tuple_element_bad_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<i32> {
// expected-error@+1 {{has return type tensor<i32>, but expected tensor<f32>}}
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<i32>

View File

@ -963,9 +963,19 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// -----
// The following op sharding is used:
// Proto debug string:
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "type: OTHER\ntile_assignment_dimensions: 1\ntile_assignment_dimensions: 2\ntile_assignment_devices: 0\ntile_assignment_devices: 1"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
@ -978,7 +988,7 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
// Tests that the exported HLO module keeps parameter replication annotation.
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<16x16xf32> {
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
@ -1008,19 +1018,19 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
// -----
// CHECK: HloModule
func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
// -----
// CHECK: HloModule
func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
return %1 : tensor<*xi32>
@ -1028,4 +1038,52 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = s32[4] parameter(0)
// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
// CHECK: ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
// -----
// Tests ops with different frontend attributes have such attributes set
// correctly in HloModule as frontend_attributes.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token> {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
%1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token>
return %1 : tuple<tensor<3x4xf32>, !xla_hlo.token>
}
// CHECK: ENTRY
// CHECK: %[[SEND:.*]] = (f32[3,4], u32[], token[]) send
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
// CHECK: %[[SEND_DONE:.*]] = token[] send-done((f32[3,4], u32[], token[]) %[[SEND]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_dtoh_0"}
// CHECK: %[[RECV:.*]] = (f32[3,4], u32[], token[]) recv(token[] %[[SEND_DONE]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
// CHECK: ROOT %{{.*}} = (f32[3,4], token[]) recv-done((f32[3,4], u32[], token[]) %[[RECV]])
// CHECK-SAME: frontend_attributes={_xla_host_transfer_original_type="f32",_xla_host_transfer_rendezvous="channel_htod_0"}
// -----
// Tests ops with empty frontend attributes do not have frontend_attributes
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
}
// CHECK-NOT: frontend_attributes
// -----
// Tests ops with no frontend attributes do not have frontend_attributes
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
}
// CHECK-NOT: frontend_attributes

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -153,13 +154,78 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(loc, operands[0], resultBuffer,
op.broadcast_dimensions());
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer});
return success();
}
private:
// Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
SmallVector<Value, 2> sizes, strides;
sizes.reserve(operand_shape.size());
strides.reserve(operand_shape.size());
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
Value broadcast_dim_value =
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
Value result_dim_size = b->create<ExtractElementOp>(
loc, op.output_dimensions(), broadcast_dim_value);
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[dim.index()])
? b->create<DimOp>(loc, operand, dim.index()).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
.getResult();
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
// tensor<index> for `output_dimensions` as well.
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
}
// There can be two cases:
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
// 2) Operand dim < result dim => expansion is needed => stride := 0.
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
operand_dim_size, result_dim_size);
strides.push_back(
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
// Size of input dim can be set to the size of the corresponding output
// dimension for both cases.
sizes.push_back(result_dim_size);
}
// Type-erased memref type with static rank, dynamic sizes and strides.
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
MemRefType::kDynamicStrideOrOffset);
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
MemRefType::kDynamicSize);
auto type_erased_memref_type = MemRefType::get(
dynamic_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand;
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
@ -281,7 +347,6 @@ class HloToLhloTensorStoreOpConverter
// "xla_lhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
// }
//
// FuncOp signature conversion example:

View File

@ -2764,6 +2764,86 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
}
};
// Converts RangeOp for cases with the length is a dynamic value. The shape of
// the resulting tensor computed, then the start and delta is used with the
// dynamic_iota value to compute the final range value.
//
// For example, the resulting range op value:
// %range = "tf.range"(%start, %limit, %delta)
//
// Is converted to the following.
// %start + %delta * iota(ceil(abs((%limit - %start) / %delta))
//
// Implementation is defined in C++ due to the complicated type behavior.
class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::RangeOp op,
PatternRewriter &rewriter) const override {
auto result = op.getResult();
auto result_type = result.getType().cast<ShapedType>();
if (result_type.hasStaticShape()) {
return failure();
}
Value start = op.start();
Value delta = op.delta();
Value limit = op.limit();
// To compute the length we need to use floating point calculations so that
// ceil can be computed for the number of steps.
auto compute_element_type =
getElementTypeOrSelf(start.getType()).isa<FloatType>()
? getElementTypeOrSelf(start.getType())
: rewriter.getF64Type();
auto compute_type = RankedTensorType::get(
limit.getType().cast<ShapedType>().getShape(), compute_element_type);
// Compute the length of the sequence we are going to need. This includes
// some conversion to float for the operations.
//
// %size = ceil(abs((%limit - %start) / %delta))
auto range = rewriter.create<xla_hlo::SubOp>(op.getLoc(), limit, start);
auto abs = rewriter.create<xla_hlo::AbsOp>(op.getLoc(), range);
// Delta is not necessarily the same type as start and limit.
auto abs_cast =
rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), compute_type, abs);
auto delta_cast =
rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), compute_type, delta);
// Compute the total number of integer steps and convert to the HLO
// dimension tensor.
auto normalized =
rewriter.create<xla_hlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
auto ceil = rewriter.create<xla_hlo::CeilOp>(op.getLoc(), normalized);
auto steps = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
auto reshape = rewriter.create<xla_hlo::ReshapeOp>(
op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
// Using the resulting length compute the correct range value:
//
// %range = %start + %delta * iota(%size)
auto out_scalar_type =
RankedTensorType::get({}, getElementTypeOrSelf(result_type));
auto start_out_cast = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), out_scalar_type, start);
auto delta_out_cast = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), out_scalar_type, delta);
auto iota = rewriter.create<DynamicIotaOp>(
op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, delta_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
op, result_type, scaled, start_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
return success();
}
};
ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
auto int_attr = attr.cast<DenseIntElementsAttr>();
auto type = val.getType().cast<ShapedType>();
@ -3222,16 +3302,19 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
// Line input up with the next dimension in broadcasted_shape
// when broadcasting.
broadcast_dimensions.push_back(broadcasted_shape.size());
int64_t broadcast_dim;
int64_t output_size = input_size * multiple;
if (input_size == 1 || multiple == 1) {
// Special case for when normal broadcasting will just work.
broadcast_dim = broadcasted_shape.size();
broadcasted_shape.push_back(output_size);
} else {
// Tiling will happen for this dimension during the ReshapeOp below.
broadcasted_shape.push_back(input_size);
broadcasted_shape.push_back(multiple);
broadcast_dim = broadcasted_shape.size();
broadcasted_shape.push_back(input_size);
}
broadcast_dimensions.push_back(broadcast_dim);
}
Location loc = op.getLoc();
Type broadcasted_type =
@ -3803,17 +3886,15 @@ class ConvertInfeedDequeueTupleOp
// Token is a control signal and not a real data, so arbitrarily assign
// the token to device 0.
if (sharding_proto.type() == ::xla::OpSharding::TUPLE)
if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
*sharding_proto.add_tuple_shardings() =
::xla::sharding_builder::AssignDevice(0);
std::string sharding_str;
if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
&sharding_str))
return failure();
data_and_token.setAttr(kShardingAttr,
rewriter.getStringAttr(sharding_str));
data_and_token.setAttr(
kShardingAttr,
rewriter.getStringAttr(sharding_proto.SerializeAsString()));
} else {
data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr());
}
}
// The infeed instruction produces a tuple of the infeed data and a token
@ -4356,21 +4437,12 @@ class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
// using a string.
if (!op._XlaSharding().hasValue()) return failure();
// _XlaSharding attribute in TF is a serialized string of the OpSharding
// proto, so convert to a text form here.
::xla::OpSharding sharding_proto;
std::string sharding_str;
if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) ||
!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto,
&sharding_str))
return failure();
auto custom_call = rewriter.create<xla_hlo::CustomCallOp>(
op.getLoc(), op.getType(), op.input(),
/*call_target_name=*/rewriter.getStringAttr("Sharding"),
/*has_side_effect=*/rewriter.getBoolAttr(false),
/*backend_config=*/rewriter.getStringAttr(""));
custom_call.setAttr(kShardingAttr, rewriter.getStringAttr(sharding_str));
custom_call.setAttr(kShardingAttr, op._XlaShardingAttr());
rewriter.replaceOp(op, custom_call.getResult());
return success();
@ -4538,6 +4610,60 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
}
};
// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
// dialect lowerings. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking into a final
// tensor.
class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::ShapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
// If the shape is static it can be canonicalized.
if (!input_ty || input_ty.hasStaticShape()) {
return failure();
}
auto result_ty = op.getResult().getType().cast<RankedTensorType>();
auto element_ty = result_ty.getElementType();
int64_t rank = input_ty.getRank();
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
auto index_ty = RankedTensorType::get({1}, element_ty);
llvm::SmallVector<Value, 4> dim_values;
for (int64_t i = 0; i < rank; ++i) {
if (!input_ty.isDynamicDim(i)) {
auto dim_attr = DenseElementsAttr::get(
index_ty,
rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
dim_values.push_back(index);
continue;
}
auto extent_op = rewriter.create<shape::GetExtentOp>(
op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
auto index_op = rewriter.create<shape::SizeToIndexOp>(
op.getLoc(), rewriter.getIndexType(), extent_op);
auto int_op =
rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
auto from_tensor = rewriter.create<TensorFromElementsOp>(
op.getLoc(), int_op.getResult());
auto reshape_op =
rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
dim_values.push_back(reshape_op);
}
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
rewriter.getI64IntegerAttr(0));
return success();
}
};
// Converts a TF QR op to HLO.
class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
public:
@ -5116,8 +5242,9 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, ConvertRangeOp,
ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,

View File

@ -271,16 +271,17 @@ class BroadcastConverter
}
};
template <typename OpTy, bool isLHLO = true>
class BroadcastInDimConverter
: public DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>,
OpTy, isLHLO> {
class HloBroadcastInDimConverter
: public DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, false> {
public:
using DataMovementOpConverter<BroadcastInDimConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
using DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp,
false>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<isLHLO>(broadcastOp);
static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp,
Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank();
@ -302,8 +303,6 @@ class BroadcastInDimConverter
int size = broadcastDim.value().getSExtValue();
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
resultType.getShape()[size] != 1;
// TODO(pifon): Add support for args with dynamic shapes for the case
// when a dimension of size 1 is broadcasted into dim of size N.
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size));
}
@ -314,6 +313,181 @@ class BroadcastInDimConverter
}
};
class LhloBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
Value operand = std::get<0>(operand_and_dims);
auto broadcast_dims = std::get<1>(operand_and_dims);
auto loc = op.getLoc();
auto nloops = result_type.getRank();
auto operand_type = operand.getType().cast<MemRefType>();
// For a degenerate case, i.e. broadcasting with expansion of
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
// Instead the value is loaded and used directly in `linalg.yield`.
if (operand_type.getRank() == 1 &&
operand_type.getDimSize(0) <
result_type.getDimSize(broadcast_dims.front())) {
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value val =
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1),
rewriter.getAffineMapArrayAttr(
{rewriter.getMultiDimIdentityMap(nloops)}),
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
block->addArgument(result_type.getElementType());
rewriter.setInsertionPointToEnd(block);
rewriter.create<linalg::YieldOp>(loc, val);
} else {
ArrayAttr indexingMapsAttr = getIndexingMapsAttr(
op, broadcast_dims, result_shape, operand_type, &rewriter);
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}),
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
block->addArguments(operand_type.getElementType());
block->addArgument(result_type.getElementType());
rewriter.setInsertionPointToEnd(block);
rewriter.create<linalg::YieldOp>(loc, block->getArgument(0));
}
rewriter.replaceOp(op, llvm::None);
return success();
}
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const {
xla_lhlo::BroadcastInDimOpOperandAdaptor operand_adaptor(args);
Value operand = operand_adaptor.operand();
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
Value result = operand_adaptor.output();
auto result_type = result.getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
SmallVector<int64_t, 2> operand_strides;
int64_t operand_offset;
if (failed(getStridesAndOffset(operand_type, operand_strides,
operand_offset))) {
op.emitOpError() << "Failed to get offset and strides.";
}
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list;
SmallVector<AffineExpr, 2> collapsed_dims;
for (const auto& item :
enumerate(op.broadcast_dimensions().getIntValues())) {
size_t index = item.index();
int dim = item.value().getSExtValue();
collapsed_dims.push_back(rewriter.getAffineDimExpr(index));
bool expansion_needed =
operand_shape[index] == 1 && result_shape[dim] != 1;
if (expansion_needed) {
continue;
}
new_shape.push_back(operand_shape[index]);
new_strides.push_back(operand_strides[index]);
broadcast_dims.push_back(dim);
collapsed_dims_list.push_back(collapsed_dims);
collapsed_dims.clear();
}
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
// and all dimensions need expansion. Such memref will be reshaped to a 1D
// memref with a single element. New shape and strides needs to be updated
// accordingly.
if (collapsed_dims_list.empty()) {
collapsed_dims_list.push_back({});
new_shape.push_back(1);
new_strides.push_back(1);
broadcast_dims.push_back(0);
}
for (const auto& dims : collapsed_dims) {
collapsed_dims_list.back().push_back(dims);
}
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
// reduced.
if (new_shape.size() < operand_shape.size()) {
SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
for (const auto& dims : collapsed_dims_list)
reassociation_maps.push_back(dims);
auto new_memref_type = MemRefType::get(
new_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(new_strides, operand_offset,
rewriter.getContext()));
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
operand_adaptor.operand(),
reassociation_maps);
}
return std::make_pair(operand, broadcast_dims);
}
ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType, Builder* b) const {
unsigned nloops = resultShape.size();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
auto operandShape = operandType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
int size = broadcastDim.value();
bool expansion_needed =
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
if (expansion_needed) {
op.emitOpError(
"BroadcastInDimOp lowering to Linalg does not support size-1 "
"dimensions expansion.");
}
dimExprs.push_back(b->getAffineDimExpr(size));
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};
/// Pattern for the special case where reshape is adding or removing a dimension
/// of size 1. These can be lowered to a linalg.generic op.
///
@ -639,9 +813,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
BroadcastInDimConverter<xla_lhlo::BroadcastInDimOp>,
ConstConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
@ -690,13 +864,12 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// ^bb0(%arg4: f32, %arg5: f32):
// %0 = addf %arg4, %arg5 : f32
// "linalg.yield"(%0) : (f32) -> ()
// }) {
// }) {
// args_in = 2,
// args_out = 1,
// indexing_maps = [#map0, #map0, #map0],
// iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// }
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalg
: public PassWrapper<LhloLegalizeToLinalg, FunctionPass> {
void runOnFunction() override {
@ -743,7 +916,7 @@ namespace xla_hlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
BroadcastInDimConverter<xla_hlo::BroadcastInDimOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,

View File

@ -1162,7 +1162,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "scan_ops_test",
size = "small",
size = "medium",
srcs = ["scan_ops_test.py"],
python_version = "PY3",
tags = [

View File

@ -1413,7 +1413,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
],
equality_test=self.ListsAreClose)
@test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer")
def testTile(self):
for dtype in self.numeric_types:
self._testBinary(

View File

@ -555,6 +555,7 @@ cc_library(
hdrs = ["convert/utils.h"],
copts = tf_copts(),
deps = [
"@com_google_absl//absl/algorithm:container",
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:lib",

View File

@ -445,16 +445,32 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
return trt_tensor;
}
// Creates a scalar constant and fills with value.
template <typename T>
Status CreateScalarConstant(
OpConverterParams* params, T value, nvinfer1::ITensor** tensor,
nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32,
const nvinfer1::Dims& dims = {1, {1}}) {
TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(trt_type, dims);
TF_RETURN_IF_ERROR(weights.SetValues(value));
*tensor = params->converter->CreateConstantLayer(weights, dims);
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
params->converter->ProvideQuantizationRange(*tensor, value, value);
return Status::OK();
}
// Creates a constant with the same rank as dims, where each dimension has
// size = 1.
Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
const nvinfer1::Dims& dims,
nvinfer1::ITensor** tensor,
const char* dtype_attr_name = "T") {
nvinfer1::DataType trt_dtype =
nvinfer1::DataType::kFLOAT; // Default to FP32.
nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT; // Default to FP32.
TFAttrs attrs(params->node_def);
if (attrs.count(dtype_attr_name)) {
DataType dtype = attrs.get<DataType>(dtype_attr_name);
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype));
TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_type));
}
// In order to be broadcastable, the number of dims has to match.
@ -462,24 +478,8 @@ Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
for (int i = 0; i < broadcastable_dims.nbDims; i++) {
broadcastable_dims.d[i] = 1;
}
TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims);
void* raw_ptr = weights.GetValues();
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
static_cast<float*>(raw_ptr)[0] = value;
break;
case nvinfer1::DataType::kHALF:
static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
break;
default:
return errors::InvalidArgument("Unsupported data type ",
DebugString(trt_dtype));
}
*tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
params->converter->ProvideQuantizationRange(*tensor, value, value);
return Status::OK();
return CreateScalarConstant(params, value, tensor, trt_type,
broadcastable_dims);
}
// Convert an axis from TF format to TRT format while validating. TF format
@ -663,6 +663,31 @@ nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const {
return nvinfer1::Weights{type_, GetValues(), count()};
}
template <typename T>
Status TRT_ShapedWeights::SetValues(T value) {
switch (type_) {
case nvinfer1::DataType::kFLOAT: {
float* ptr = tensor_.flat<float>().data();
std::fill(ptr, ptr + count(), value);
break;
}
case nvinfer1::DataType::kHALF: {
Eigen::half* ptr = tensor_.flat<Eigen::half>().data();
std::fill(ptr, ptr + count(), Eigen::half(value));
break;
}
case nvinfer1::DataType::kINT32: {
int32* ptr = tensor_.flat<int32>().data();
std::fill(ptr, ptr + count(), value);
break;
}
default:
return errors::InvalidArgument("Unsupported data type ",
tensorflow::tensorrt::DebugString(type_));
}
return Status::OK();
}
size_t TRT_ShapedWeights::size_bytes() const {
size_t data_type_size = -1;
switch (type_) {
@ -1297,10 +1322,14 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype,
// We verify the batch size only for the input nodes, and rely on individual
// op converter to ensure the batch size of the outputs is not changed.
// TODO(laigd): we need to test this properties.
Status status = MaybeUpdateBatchSize(batch_size);
if (!status.ok()) {
return Status(status.code(), StrCat("Batch size doesn't match for tensor ",
name, ": ", status.error_message()));
Status status;
if (use_implicit_batch_) {
status = MaybeUpdateBatchSize(batch_size);
if (!status.ok()) {
return Status(status.code(),
StrCat("Batch size doesn't match for tensor ", name, ": ",
status.error_message()));
}
}
nvinfer1::ITensor* tensor = network()->addInput(name.c_str(), dtype, dims);
if (tensor == nullptr) {
@ -1893,32 +1922,34 @@ Status Converter::GetInputs(const NodeDef& node_def,
return Status::OK();
}
enum class TrtInputArg { kTensor = 1, kWeight = 2, kBoth = 3 };
// Checks that the number of inputs match, and enforces that the inputs marked
// as true are constant weights. true means that the input must be a weight,
// while false means the input must be a tensor. In the future, false will mean
// the input can be a tensor or weight.
// as weights are constant. Inputs are allowed to be both weight and tensor.
Status CheckInputsWeights(
const OpConverterParams& params,
const std::vector<std::pair<string, bool>>& inputs_is_weight) {
const std::vector<std::pair<string, TrtInputArg>>& expected_inputs) {
const auto& inputs = params.inputs;
const auto& node_def = params.node_def;
if (inputs.size() != inputs_is_weight.size()) {
if (inputs.size() != expected_inputs.size()) {
return errors::InvalidArgument(
node_def.op(), " got ", inputs.size(), " inputs but expected ",
inputs_is_weight.size(), ", at ", node_def.name());
expected_inputs.size(), ", at ", node_def.name());
}
for (int i = 0; i < inputs.size(); i++) {
if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) {
return errors::Unimplemented("The input \"", inputs_is_weight[i].first,
if (expected_inputs[i].second == TrtInputArg::kWeight &&
inputs.at(i).is_tensor()) {
return errors::Unimplemented("The input \"", expected_inputs[i].first,
"\" for ", node_def.op(),
" must be a constant, at ", node_def.name());
}
// TODO(tmorris): Remove this check and provide a method to automatically
// TODO(tfeher): Remove this check and provide a method to automatically
// retrieve an input as a tensor, converting via CreateConstantLayer if it
// was originally a weight. We will want a caching mechanism to prevent many
// duplicate constants from being created.
if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) {
return errors::Unimplemented("The input \"", inputs_is_weight[i].first,
if (expected_inputs[i].second == TrtInputArg::kTensor &&
inputs.at(i).is_weights()) {
return errors::Unimplemented("The input \"", expected_inputs[i].first,
"\" for ", node_def.op(),
" must be a tensor, at ", node_def.name());
}
@ -1926,6 +1957,23 @@ Status CheckInputsWeights(
return Status::OK();
}
// Checks that the number of inputs match, and enforces that the inputs marked
// as true are constant weights. true means that the input must be a weight,
// while false means the input must be a tensor.
Status CheckInputsWeights(
const OpConverterParams& params,
const std::vector<std::pair<string, bool>>& inputs_is_weight) {
std::vector<std::pair<string, TrtInputArg>> expected_inputs;
expected_inputs.reserve(inputs_is_weight.size());
std::transform(
inputs_is_weight.begin(), inputs_is_weight.end(),
std::back_inserter(expected_inputs), [](std::pair<string, bool> x) {
return std::make_pair(
x.first, x.second ? TrtInputArg::kWeight : TrtInputArg::kTensor);
});
return CheckInputsWeights(params, expected_inputs);
}
Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type,
const char* type_attr_name) {
TFAttrs attrs(node_def);
@ -2451,53 +2499,114 @@ Status ConvertExpandDims(OpConverterParams* params) {
// ExpandDim's ability to add an axis at end of the shape.
int trt_axis;
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(),
/*use_implicit_batch=*/true, &trt_axis));
params->use_implicit_batch, &trt_axis));
if (params->validation_only) return Status::OK();
// ExpandDims: Insert new dim of size 1.
input_dims.insert(input_dims.begin() + trt_axis, 1);
// Reshape tensor.
nvinfer1::Dims new_dims;
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims));
nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
input_tensor, new_dims, /*validation_only=*/false, &output_tensor));
if (!params->use_implicit_batch && !HasStaticShape(input_dims)) {
TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims(
input_tensor.tensor(), dims, trt_axis, params, &output_tensor));
} else {
// ExpandDims: Insert new dim of size 1.
input_dims.insert(input_dims.begin() + trt_axis, 1);
// Reshape tensor.
nvinfer1::Dims new_dims;
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims));
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
input_tensor, new_dims, /*validation_only=*/false, &output_tensor));
}
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
Status Converter::DynamicReshape(nvinfer1::ITensor* input,
std::vector<std::pair<int, int>> slices,
OpConverterParams* params,
nvinfer1::ITensor** output,
std::vector<int> size_for_added_dims) {
*output = nullptr;
// DynamicReshape relies on INetworkDefinition::addShape that was introduced
// in TensorRT 6.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
if (params->validation_only) {
return errors::Internal(
"DynamicReshape should not be used during validation");
}
nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0);
// Build new shape = shape[:trt_axis] + [1] + shape[trt_axis:]
std::vector<nvinfer1::ITensor const*> concat_inputs;
for (int i = 0; i < std::max(slices.size(), size_for_added_dims.size());
i++) {
nvinfer1::ITensor* tensor;
// maybe_add_a_dimension(i);
if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) {
TF_RETURN_IF_ERROR(
CreateScalarConstant(params, size_for_added_dims[i], &tensor));
concat_inputs.push_back(tensor);
}
if (i < slices.size()) {
concat_inputs.push_back(
network()
->addSlice(*shape, {1, {slices[i].first}},
{1, {slices[i].second - slices[i].first}}, {1, {1}})
->getOutput(0));
}
}
nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
concat_inputs.size());
concat_layer->setAxis(0);
nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
// Reshape input using new shape
nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input);
shuffle->setInput(1, *new_shape);
*output = shuffle->getOutput(0);
return Status::OK();
#else
return errors::Unavailable(
"Dynamic shape input requires TensorRT 6 or above");
#endif
}
Status Converter::DynamicExpandDims(nvinfer1::ITensor* input,
const nvinfer1::Dims& dims, int axis,
OpConverterParams* params,
nvinfer1::ITensor** output) {
if (params->validation_only) {
*output = nullptr;
return errors::Internal(
"DynamicExpandDims should not be used during validation");
}
std::vector<std::pair<int, int>> slices;
std::vector<int> extra_dims;
if (axis != 0) {
slices.push_back(std::pair<int, int>{0, axis});
extra_dims.push_back(-1);
}
extra_dims.push_back(1);
if (axis != dims.nbDims) {
slices.push_back(std::pair<int, int>{axis, dims.nbDims});
}
return DynamicReshape(input, slices, params, output, extra_dims);
}
Status Converter::SqueezeTensor(nvinfer1::ITensor* input,
std::vector<int>* input_dims,
OpConverterParams* params,
nvinfer1::ITensor** output) {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
// If the remaining dimensions of a squeeze operation have dynamic sizes, we
// need to use TRT ops to build the result shape for the squeeze operation.
// This is because IShuffleLayer::setReshapeDimensions treats -1 as a special
// value.
if (absl::c_any_of(*input_dims, [](int i) { return i == -1; })) {
nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0);
std::vector<nvinfer1::ITensor const*> concat_inputs;
if (!params->use_implicit_batch && !HasStaticShape(*input_dims)) {
std::vector<std::pair<int, int>> slices;
for (int i = 0; i < input_dims->size(); i++) {
// If input dim wasn't set to 0 earlier, we include it in new shape.
if (input_dims->at(i) != 0) {
concat_inputs.push_back(
network()
->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}})
->getOutput(0));
slices.push_back(std::pair<int, int>(i, i + 1));
}
}
nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(concat_inputs.data()),
concat_inputs.size());
concat_layer->setAxis(0);
nvinfer1::ITensor* new_shape = concat_layer->getOutput(0);
// Reshape input using new shape
nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input);
shuffle->setInput(1, *new_shape);
*output = shuffle->getOutput(0);
return Status::OK();
return DynamicReshape(input, slices, params, output);
}
#endif
// Remove all dims which are equal to 0.
input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0),
input_dims->end());
@ -2564,7 +2673,7 @@ Status ConvertSqueeze(OpConverterParams* params) {
nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->SqueezeTensor(
input_tensor.tensor(), &input_dims, &output_tensor));
input_tensor.tensor(), &input_dims, params, &output_tensor));
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
}
@ -4149,8 +4258,13 @@ Status ConvertBinary(OpConverterParams* params) {
" inputs but expected 2, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF,
DataType::DT_INT32};
#else
std::set<DataType> allowed_types{DataType::DT_FLOAT, DataType::DT_HALF};
#endif
TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types));
// Constant folding should have been done by TensorFlow
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
@ -4896,24 +5010,14 @@ Status ConvertGather(OpConverterParams* params) {
const auto& node_def = params->node_def;
// TODO(tmorris): Use CheckInputsWeights by changing bool to enum with an
// option for an input to be either tensor or weight.
if (inputs.size() != 3) {
return errors::InvalidArgument("GatherV2 got ", inputs.size(),
" inputs but expected 3, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"params", TrtInputArg::kBoth},
{"indices", TrtInputArg::kTensor},
{"axis", TrtInputArg::kWeight}}));
const auto& params_input = inputs.at(0);
const auto& indices_input = inputs.at(1);
const auto& axis_input = inputs.at(2);
if (!axis_input.is_weights()) {
return errors::Unimplemented(
"The input \"axis\" for GatherV2 must be a constant, at ",
node_def.name());
}
if (!indices_input.is_tensor()) {
return errors::Unimplemented(
"The input \"indices\" for GatherV2 must be a tensor, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
@ -4927,14 +5031,16 @@ Status ConvertGather(OpConverterParams* params) {
node_def.name());
}
int trt_axis = 0;
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], params_input.GetTrtDims().nbDims,
node_def.name(), params_input.is_tensor(),
&trt_axis));
if (params_input.is_weights() && trt_axis != 0) {
TF_RETURN_IF_ERROR(ConvertAxis(
axis[0], params_input.GetTrtDims().nbDims, node_def.name(),
params->use_implicit_batch && params_input.is_tensor(), &trt_axis));
if (params->use_implicit_batch && params_input.is_weights() &&
trt_axis != 0) {
return errors::Unimplemented(
"The input axis must be zero when params is a weight.");
}
if (params_input.is_tensor() && indices_input.batch_size() != 1) {
if (params->use_implicit_batch && params_input.is_tensor() &&
indices_input.batch_size() != 1) {
return errors::Unimplemented(
"Indices must have a batch size of 1 when params is a tensor.");
}
@ -4943,10 +5049,13 @@ Status ConvertGather(OpConverterParams* params) {
// where "+ 1" adds the batch dim. If params is a weight, the TRT rank matches
// the TF rank so we don't have to add + 1.
const int params_tf_rank =
params_input.GetTrtDims().nbDims + (params_input.is_tensor() ? 1 : 0);
const int indices_tf_rank = indices_input.GetTrtDims().nbDims + 1;
params_input.GetTrtDims().nbDims +
(params->use_implicit_batch && params_input.is_tensor() ? 1 : 0);
const int indices_tf_rank =
indices_input.GetTrtDims().nbDims + (params->use_implicit_batch ? 1 : 0);
const int tf_gather_output_rank = params_tf_rank + indices_tf_rank - 1;
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
if (tf_gather_output_rank >
nvinfer1::Dims::MAX_DIMS + (params->use_implicit_batch ? 1 : 0)) {
return errors::InvalidArgument(
"Result of gather has dimension greater than ",
nvinfer1::Dims::MAX_DIMS + 1);
@ -4978,7 +5087,8 @@ Status ConvertGather(OpConverterParams* params) {
// because of the implicit batch dim in the indices (see the above note).
const int expected_trt_output_rank =
tf_gather_output_rank - (params_input.is_tensor() ? 2 : 1);
if (trt_gather_output_dims.nbDims != expected_trt_output_rank) {
if (params->use_implicit_batch &&
trt_gather_output_dims.nbDims != expected_trt_output_rank) {
return errors::Internal(
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
expected_trt_output_rank,
@ -4986,7 +5096,7 @@ Status ConvertGather(OpConverterParams* params) {
}
// Reshape the output so after adding the implicit batch dim it'll match the
// output shape of TF GatherV2.
if (params_input.is_tensor()) {
if (params->use_implicit_batch && params_input.is_tensor()) {
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
}

View File

@ -185,6 +185,10 @@ class TRT_ShapedWeights {
return const_cast<char*>(tensor_.tensor_data().data());
}
// Fills all the weight values with value.
template <typename T>
Status SetValues(T value);
int64_t count() const;
size_t size_bytes() const;
@ -529,12 +533,62 @@ class Converter {
const bool validation_only,
nvinfer1::ITensor** tensor);
// Reshapes a dynamic shape tensor by removing or adding dimensions of size 1,
// and/or permuting the dimensions. The new shape is derived from the shape of
// the input tensor according to the slices and size_for_added_dims arguments.
//
// If there would be at most one unknown dimension, we could set the new shape
// using IShuffleLayer::setReshapeDimensions, which treats -1 as a special
// value (the same way as TF). In general, we can have more than one unknown
// dimensions, and we have to manipulate the shape tensors during runtime to
// define the new shape. This helper function defines the necessary shape
// inference layers and calls reshape using the calculated new shape.
//
// Example:
//
// Assume that we want to reshape a tensor from shape {A,B,C,D} to {C,D,A,B}
// (no transpose, just change the shape). In dynamic shape mode, the A,B,C,D
// values are not necessarily known at conversion time, they can be all -1. We
// can only define the new shape at runtime, when the actual shape is already
// known. To define the new shape:
// - We use an IShapeLayer to retrieve a shape tensor with the {A,B,C,D}
// values.
// - Create two slices {C,D} and {A,B} of the shape tensor.
// - Concatenate these slices {C,D,A,B},
// - Set the {C,D,A,B} shape tensor as an input shape tensor for
// IShuffleLayer.
//
// This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
// params).
//
// Before each slice we can insert a new dim if the corresponding
// size_for_added_dims element is not negative. The size_for_added_dims array
// can have more than slices.size() elements, in order to insert a dimension
// ater the last slice.
//
// Parameters:
// input - input tensor
// slices - [start, end) pairs of slices
// params - conversion parameters
// output - reshaped tensor
// size_for_added_dims - size of dimension inserted right before slice[i]. We
// only insert a new dim if size_for_added_dims[i] >= 0.
Status DynamicReshape(nvinfer1::ITensor* input,
std::vector<std::pair<int, int>> slices,
OpConverterParams* params, nvinfer1::ITensor** output,
std::vector<int> size_for_added_dims = {});
// Inserts a singleton dimension at axis for a dynamic shape tensor.
Status DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims,
int axis, OpConverterParams* params,
nvinfer1::ITensor** output);
// Helper function to add a squeeze op to the network.
//
// The input_dims argument stores the TRT dimensions of the input tensor,
// where the dimensions to be squeezed are replaced by 0.
Status SqueezeTensor(nvinfer1::ITensor* input, std::vector<int>* input_dims,
nvinfer1::ITensor** output);
OpConverterParams* params, nvinfer1::ITensor** output);
// Creates an IConstantLayer using 'weights' whose dimensions are specified by
// 'dims', and returns the output ITensor.
@ -632,6 +686,9 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
// Map of all supported ActivationTypes
const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
// Map of all supported BinaryOperations
const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
BinaryOperationMap();
} // namespace convert
} // namespace tensorrt

View File

@ -1359,25 +1359,25 @@ class OpConverterTest : public ::testing::Test {
}
// Constructs a tensor with given values (vals). The tensor type is defined by
// the tf_dtype argument, its shape is given by input_dims. The tensor is
// the tf_type argument, its shape is given by input_dims. The tensor is
// constructed using the allocator of OpConverterTest in Unified Memory.
template <typename T>
Tensor AsTensor(std::vector<T> vals, const std::vector<int> input_dims,
DataType tf_dtype) {
Tensor ret(allocator_.get(), tf_dtype, {static_cast<int64>(vals.size())});
if (tf_dtype == DT_FLOAT) {
DataType tf_type) {
Tensor ret(allocator_.get(), tf_type, {static_cast<int64>(vals.size())});
if (tf_type == DT_FLOAT) {
auto conv_vals = CastTestVector<T, float>(vals);
std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<float>().data());
} else if (tf_dtype == DT_HALF) {
} else if (tf_type == DT_HALF) {
auto conv_vals = CastTestVector<T, Eigen::half>(vals);
std::copy_n(conv_vals.data(), conv_vals.size(),
ret.flat<Eigen::half>().data());
} else if (tf_dtype == DT_INT32) {
} else if (tf_type == DT_INT32) {
auto conv_vals = CastTestVector<T, int32>(vals);
std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<int32>().data());
} else {
LOG(FATAL) << "Cannot create tensor with type "
<< DataTypeString(tf_dtype);
<< DataTypeString(tf_type);
}
TensorShape shape;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_dims, &shape));
@ -1394,9 +1394,9 @@ class OpConverterTest : public ::testing::Test {
// Constructs a flat tensor in Unified Memory.
template <typename T>
Tensor ConstructTensor(int data_size, const T& value, DataType tf_dtype) {
Tensor ConstructTensor(int data_size, const T& value, DataType tf_type) {
std::vector<T> values(data_size, value);
return AsTensor<T>(values, {data_size}, tf_dtype);
return AsTensor<T>(values, {data_size}, tf_type);
}
void CheckDataTypeMatches(const DataVec& datas) {
@ -1405,10 +1405,10 @@ class OpConverterTest : public ::testing::Test {
ASSERT_NE(-1, input_index);
const nvinfer1::DataType trt_dtype =
engine_->getBindingDataType(input_index);
const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
ASSERT_EQ(data.tensor.dtype(), tf_dtype)
const DataType tf_type = TrtDataTypeToTf(trt_dtype);
ASSERT_EQ(data.tensor.dtype(), tf_type)
<< DataTypeString(data.tensor.dtype()) << " vs. "
<< DataTypeString(tf_dtype);
<< DataTypeString(tf_type);
}
}
@ -1489,12 +1489,13 @@ class OpConverterTest : public ::testing::Test {
// dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}).
void AddTestTensorWithTFDims(
const string& name, const std::vector<int32>& dims,
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT,
Status add_input_status = Status::OK()) {
DataType tf_type = TrtDataTypeToTf(trt_type);
ops::Placeholder::Attrs attrs;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs);
auto input = ops::Placeholder(scope_.WithOpName(name), tf_type, attrs);
node_inputs_[name] = input.output;
// Add a real ITensor for conversion conditionally.
@ -1502,8 +1503,9 @@ class OpConverterTest : public ::testing::Test {
TensorShapeToTrtDims(attrs.shape_, converter_->use_implicit_batch());
if (!converter_->use_implicit_batch() || HasStaticShape(trt_dims)) {
int batch_size = dims[0];
TF_EXPECT_OK(
converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size));
Status status =
converter_->AddInputTensor(name, trt_type, trt_dims, batch_size);
ASSERT_EQ(add_input_status, status);
}
}
@ -1552,24 +1554,23 @@ class OpConverterTest : public ::testing::Test {
converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights}));
}
template <typename T>
template <typename T = int32>
void AddTestWeights(const string& name, const std::vector<int>& dims,
const std::vector<T>& values, DataType tf_dtype) {
if (tf_dtype == DT_FLOAT) {
const std::vector<T>& values, DataType tf_type) {
if (tf_type == DT_FLOAT) {
AddTestWeights(name, dims, CastTestVector<T, float>(values));
} else if (tf_dtype == DT_HALF) {
} else if (tf_type == DT_HALF) {
AddTestWeights(name, dims, CastTestVector<T, Eigen::half>(values));
} else if (tf_dtype == DT_INT32) {
} else if (tf_type == DT_INT32) {
AddTestWeights(name, dims, CastTestVector<T, int32>(values));
} else {
FAIL() << "Cannot create test weights with type "
<< DataTypeString(tf_dtype);
<< DataTypeString(tf_type);
}
}
// Test validation in validation-only mode.
void RunValidation(const Node* node, error::Code expected_code = error::OK,
const char* expected_msg_substr = nullptr) {
Status RunValidation(const Node* node) {
grappler::GrapplerItem item;
TF_EXPECT_OK(scope_.ToGraphDef(&item.graph));
grappler::GraphProperties graph_properties(item);
@ -1578,8 +1579,7 @@ class OpConverterTest : public ::testing::Test {
TrtNodeValidator validator(graph_properties, converter_->precision_mode(),
/*use_calibration=*/false,
converter_->use_implicit_batch());
ExpectStatus(validator.IsTensorRTCandidate(node), expected_code,
expected_msg_substr);
return validator.IsTensorRTCandidate(node);
}
void RunConversion(const Node* node, error::Code expected_code = error::OK,
@ -1610,9 +1610,11 @@ class OpConverterTest : public ::testing::Test {
graph->AddEdge(input.node(), input.index(), node, i);
}
RunValidation(node, expected_code, expected_msg_substr);
if (should_run_conversion) {
status = RunValidation(node);
if (should_run_conversion && status.ok()) {
RunConversion(node, expected_code, expected_msg_substr);
} else {
ExpectStatus(status, expected_code, expected_msg_substr);
}
}
@ -1717,7 +1719,7 @@ class ParameterizedOpConverterTestBase
public:
ParameterizedOpConverterTestBase()
: trt_mode(std::get<0>(GetParam())),
tf_dtype(std::get<1>(GetParam())),
tf_type(std::get<1>(GetParam())),
converter_precision(std::get<2>(GetParam())) {}
void Reset() {
@ -1744,11 +1746,15 @@ class ParameterizedOpConverterTestBase
// be empty, in that case the partial_input_shape will be set automatically
// depending on the trt_mode argument. (This argument also includes explicit
// batch dim).
// - add_input_status adding ITensor to the network can fail in implicit batch
// mode if the batch size is inconsistent. Using the add_input_status arg we
// can test such errors.
//
template <typename T>
template <typename T = int>
void AddTestTensor(const string& name, const std::vector<int32>& dims,
DataType tf_dtype, const std::vector<T>& values,
const std::vector<int32>& partial_input_shape_dims = {}) {
DataType tf_type, const std::vector<T>& values,
const std::vector<int32>& partial_input_shape_dims = {},
Status add_input_status = Status::OK()) {
std::vector<int32> partial_shape;
if (!partial_input_shape_dims.empty()) {
partial_shape = partial_input_shape_dims;
@ -1761,24 +1767,25 @@ class ParameterizedOpConverterTestBase
partial_shape = dims;
}
}
AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_dtype));
AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_type),
add_input_status);
if (!values.empty()) {
VLOG(2) << "Adding test tensor: " << name << " "
<< DataTypeString(tf_dtype);
InputOutputData data{name, AsTensor(values, dims, tf_dtype)};
<< DataTypeString(tf_type);
InputOutputData data{name, AsTensor(values, dims, tf_type)};
VLOG(2) << "Added tensor: " << data.name
<< DataTypeString(data.tensor.dtype());
input_data_.push_back(data);
}
}
// Adds test tensor (same as above) but with the default tf_dtype defined by
// Adds test tensor (same as above) but with the default tf_type defined by
// the test params.
template <typename T = int>
void AddTestTensor(const string& name, const std::vector<int32>& dims,
const std::vector<float>& values = {},
const std::vector<T>& values = {},
const std::vector<int32>& partial_input_shape_dims = {}) {
AddTestTensor<float>(name, dims, tf_dtype, values,
partial_input_shape_dims);
AddTestTensor<T>(name, dims, tf_type, values, partial_input_shape_dims);
}
// Builds and runs the converted network. Checks output tensor shape. Tests
@ -1797,7 +1804,7 @@ class ParameterizedOpConverterTestBase
TensorShapeUtils::MakeShape(expected_output_dims[i], &shape));
string out_name = (n_output == 1) ? name : StrCat(name, ":", i);
InputOutputData data{out_name,
ConstructTensor(shape.num_elements(), 0, tf_dtype)};
ConstructTensor(shape.num_elements(), 0, tf_type)};
output_data.push_back(data);
}
ASSERT_FALSE(input_data_.empty());
@ -1836,7 +1843,7 @@ class ParameterizedOpConverterTestBase
protected:
const TrtTestMode trt_mode;
const DataType tf_dtype;
const DataType tf_type;
const TrtPrecisionMode converter_precision;
DataVec input_data_;
};
@ -1869,6 +1876,15 @@ INSTANTIATE_TEST_CASE_P(
::testing::Combine(::testing::ValuesIn(ValidTrtModes),
::testing::Values(DT_FLOAT, DT_HALF),
::testing::Values(TrtPrecisionMode::FP32)));
// Base class for tests that need to be tested for FP32, FP16, and INT32
class OpConverterTest3 : public ParameterizedOpConverterTestBase {};
INSTANTIATE_TEST_CASE_P(
OpConvTestInstantiation3, OpConverterTest3,
::testing::Combine(::testing::ValuesIn(ValidTrtModes),
::testing::Values(DT_FLOAT, DT_HALF, DT_INT32),
::testing::Values(TrtPrecisionMode::FP32)));
template <typename T>
void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) {
out->Clear();
@ -2009,7 +2025,7 @@ TEST_F(OpConverterTest, ConvertConst) {
TEST_P(OpConverterTest1, ConvertTranspose) {
// Get the NodeDef for Transpose.
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
const NodeDef& node_def = transpose.operation.node()->def();
@ -2408,10 +2424,10 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) {
// DT_INT32 type here. DT_FLOAT and DT_HALF are tested.
// Get the NodeDef for BiasAdd.
auto get_biasadd_nodedef = [](const string& data_format,
DataType tf_dtype) -> NodeDef {
DataType tf_type) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype);
auto weights = ops::Placeholder(s.WithOpName("weights"), tf_dtype);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
auto weights = ops::Placeholder(s.WithOpName("weights"), tf_type);
const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format);
auto biasadd =
ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs);
@ -2421,7 +2437,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) {
for (const string& data_format : {"NHWC", "NCHW"}) {
for (const int trt_input_rank : {1, 2, 3, 4}) {
Reset();
NodeDef node_def = get_biasadd_nodedef(data_format, tf_dtype);
NodeDef node_def = get_biasadd_nodedef(data_format, tf_type);
// Add input, dims_array will be like {2, 1, ..., 1, 3}
std::vector<int32> dims_array(trt_input_rank + 1, 1);
@ -2443,7 +2459,7 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) {
for (int i = 0; i < channel_size; ++i) {
bias[i] = i + 1; // bias will be {1, 2, 3, ...}
}
AddTestWeights("weights", {channel_size}, bias, tf_dtype);
AddTestWeights("weights", {channel_size}, bias, tf_type);
// Build and run the engine.
std::vector<float> output_data;
@ -2468,103 +2484,18 @@ TEST_P(OpConverterTest2, ConvertBiasAdd) {
}
template <typename OpType>
NodeDef GetBinaryOpNodeDef(const string& input_name_l,
const string& input_name_r, DataType dtype) {
NodeDef GetBinaryOpNodeDef(DataType dtype) {
Scope s = Scope::NewRootScope();
auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype);
auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype);
auto input_l = ops::Placeholder(s.WithOpName("input1"), dtype);
auto input_r = ops::Placeholder(s.WithOpName("input2"), dtype);
auto op = OpType(s.WithOpName("my_binary"), input_l, input_r);
return op.operation.node()->def();
}
template <typename OpType, DataType dtype>
void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor,
bool operand_2_is_tensor) {
typedef typename EnumToDataType<dtype>::Type CType;
test->Reset();
const NodeDef node_def =
GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
if (operand_1_is_tensor) {
test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/2,
TfDataTypeToTrt(dtype));
} else {
test->AddTestWeights("input1", /*dims=*/{1, 2},
/*values=*/std::vector<CType>{CType(3), CType(6)});
}
if (operand_2_is_tensor) {
test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/2,
TfDataTypeToTrt(dtype));
} else {
test->AddTestWeights("input2", /*dims=*/{2, 1},
/*values=*/std::vector<CType>{CType(2), CType(3)});
}
test->RunValidationAndConversion(node_def);
DataVec input_data;
if (operand_1_is_tensor) {
input_data.push_back(
{"input1",
test->AsTensor<CType>({CType(3), CType(6), CType(3), CType(6)})});
}
if (operand_2_is_tensor) {
input_data.push_back(
{"input2",
test->AsTensor<CType>({CType(2), CType(3), CType(2), CType(3)})});
}
DataVec output_data{{"my_binary", test->ConstructTensor<CType>(8)}};
// Check output dims.
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
// After broadcasting first input becomes {3, 6, 3, 6} and second input
// becomes {2, 3, 2, 3}.
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2));
if (node_def.op() == "Add") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({5, 8, 6, 9, 5, 8, 6, 9})));
} else if (node_def.op() == "Sub") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({1, 4, 0, 3, 1, 4, 0, 3})));
} else if (node_def.op() == "Mul") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(
CastTestVector<int, CType>({6, 12, 9, 18, 6, 12, 9, 18})));
} else if (node_def.op() == "Div") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "RealDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<float, CType>(
{1.5, 3, 1, 2, 1.5, 3, 1, 2})));
} else if (node_def.op() == "FloorDiv") {
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(
CastTestVector<float, CType>({1, 3, 1, 2, 1, 3, 1, 2})));
} else if (node_def.op() == "Minimum") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({2, 2, 3, 3, 2, 2, 3, 3})));
} else if (node_def.op() == "Maximum") {
EXPECT_THAT(
GetSpanForData<CType>(output_data[0]),
ElementsAreArray(CastTestVector<int, CType>({3, 6, 3, 6, 3, 6, 3, 6})));
} else if (node_def.op() == "Pow") {
ExpectArrayNear(
CastTestVector<int, CType>({9, 36, 27, 216, 9, 36, 27, 216}),
GetSpanForData<CType>(output_data[0]));
} else {
ASSERT_TRUE(false);
}
}
TEST_F(OpConverterTest, ConvertBinary) {
AttrValue dtype;
dtype.set_type(DT_FLOAT);
TEST_P(OpConverterTest2, ConvertBinary) {
{
AttrValue dtype;
dtype.set_type(tf_type);
// Both inputs are weights.
Reset();
NodeDef node_def =
@ -2577,46 +2508,56 @@ TEST_F(OpConverterTest, ConvertBinary) {
"both input as constant at: my_add");
}
using OpFunc = std::function<NodeDef(DataType)>;
std::map<std::string, std::pair<OpFunc, std::vector<float>>> op_test_info;
#define ADD_OP(name, op, v1, v2, v3, v4, v5, v6, v7, v8) \
op_test_info[name] = \
std::make_pair(GetBinaryOpNodeDef<op>, \
std::vector<float>(v1, v2, v3, v4, v5, v6, v7, v8))
ADD_OP("Add", ops::Add, {5, 8, 6, 9, 5, 8, 6, 9});
ADD_OP("AddV2", ops::AddV2, {5, 8, 6, 9, 5, 8, 6, 9});
ADD_OP("Sub", ops::Sub, {1, 4, 0, 3, 1, 4, 0, 3});
ADD_OP("Mul", ops::Mul, {6, 12, 9, 18, 6, 12, 9, 18});
ADD_OP("Div", ops::Div, {1.5, 3, 1, 2, 1.5, 3, 1, 2});
ADD_OP("RealDiv", ops::RealDiv, {1.5, 3, 1, 2, 1.5, 3, 1, 2});
ADD_OP("FloorDiv", ops::FloorDiv, {1, 3, 1, 2, 1, 3, 1, 2});
ADD_OP("Minimum", ops::Minimum, {2, 2, 3, 3, 2, 2, 3, 3});
ADD_OP("Maximum", ops::Maximum, {3, 6, 3, 6, 3, 6, 3, 6});
ADD_OP("Pow", ops::Pow, {9, 36, 27, 216, 9, 36, 27, 216});
#undef ADD_OP
// Add all ops supported by ConvertBinary.
auto* supported_ops = BinaryOperationMap();
// Test combinations of tensor vs weight inputs (except when both inputs are
// weights).
for (const bool operand_1_is_tensor : {true, false}) {
for (const bool operand_2_is_tensor : {true, false}) {
if (!operand_1_is_tensor && !operand_2_is_tensor) continue;
// FP32 tests
TestBinaryOp<ops::Add, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Sub, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Mul, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Div, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::RealDiv, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Minimum, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Maximum, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Pow, DT_FLOAT>(this, operand_1_is_tensor,
operand_2_is_tensor);
// FP16 tests
// TODO(tmorris): Use templates to avoid duplication.
TestBinaryOp<ops::Add, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Sub, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Mul, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Div, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::RealDiv, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Minimum, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Maximum, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
TestBinaryOp<ops::Pow, DT_HALF>(this, operand_1_is_tensor,
operand_2_is_tensor);
for (auto& iter : *supported_ops) {
string op_name = iter.first;
SCOPED_TRACE(StrCat(op_name, "_", operand_1_is_tensor ? "T" : "W",
operand_2_is_tensor ? "T" : "W"));
Reset();
if (!op_test_info.count(op_name)) {
FAIL() << "Binary op test map does not contain op " << op_name;
}
NodeDef node_def = op_test_info[op_name].first(tf_type);
std::vector<std::string> input_names;
std::vector<std::vector<int>> input_dims;
std::vector<std::vector<float>> input_values;
if (operand_1_is_tensor) {
AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
} else {
AddTestWeights("input1", {1, 2}, std::vector<float>{3, 6}, tf_type);
}
if (operand_2_is_tensor) {
AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
} else {
AddTestWeights("input2", {2, 1}, std::vector<float>{2, 3}, tf_type);
}
TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
Status::OK(),
ElementsAreArray(op_test_info[op_name].second));
}
}
}
}
@ -2966,17 +2907,17 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) {
#endif // IS_TRT_VERSION_GE(5, 1, 0, 0)
template <typename T>
NodeDef CreateUnaryOp(DataType tf_dtype) {
NodeDef CreateUnaryOp(DataType tf_type) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
return T(s.WithOpName("my_unary"), input).operation.node()->def();
}
constexpr float kLeakyReluAlpha = 0.2f;
template <>
NodeDef CreateUnaryOp<ops::internal::LeakyRelu>(DataType tf_dtype) {
NodeDef CreateUnaryOp<ops::internal::LeakyRelu>(DataType tf_type) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
return ops::internal::LeakyRelu(
s.WithOpName("my_unary"), input,
ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha))
@ -2988,7 +2929,7 @@ TEST_P(OpConverterTest1, ConvertActivation) {
{
// Input is weights, should fail.
Reset();
const NodeDef& node_def = CreateUnaryOp<ops::Relu>(tf_dtype);
const NodeDef& node_def = CreateUnaryOp<ops::Relu>(tf_type);
AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
@ -3045,7 +2986,7 @@ TEST_P(OpConverterTest1, ConvertActivation) {
FAIL() << "Activation op test map does not contain op " << op_name;
}
Reset();
NodeDef node_def = op_map[op_name].first(tf_dtype);
NodeDef node_def = op_map[op_name].first(tf_type);
const std::vector<float> input = {-100, -2, -1, 0, 1, 88};
AddTestTensor("input", p.input_dims, input);
@ -3070,10 +3011,10 @@ TEST_P(OpConverterTest1, ConvertActivation) {
}
}
TEST_F(OpConverterTest, ConvertExpandDims) {
TEST_P(OpConverterTest1, ConvertExpandDims) {
// Get the NodeDef for ExpandDims.
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
auto expanddims =
ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
@ -3090,85 +3031,60 @@ TEST_F(OpConverterTest, ConvertExpandDims) {
{
// Axis is a tensor, should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {3, 2, 1});
AddTestTensor("weights", {3});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"axis\" for ExpandDims must be a "
"constant, at my_expanddims");
}
{
// Add dim at batch dimension, should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", {1}, {0});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the batch dimension, at "
"my_expanddims");
}
{
// Add dim at batch dimension via negative axis, should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
// Input is rank 4 (batch dim included)
AddTestWeights<int32>("weights", {1}, {-5});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the batch dimension, at "
"my_expanddims");
}
{
// Axis > rank(input), should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
// Input is rank 4 (batch dim included)
AddTestWeights<int32>("weights", {1}, {5});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Axis value of 5 is out of bounds, must be in range [-5, 5), at "
"my_expanddims");
}
{
// Axis < -rank(input)-1, should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
// Input is rank 4 (batch dim included)
AddTestWeights<int32>("weights", {1}, {-6});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Axis value of -6 is out of bounds, must be in range [-5, 5), at "
"my_expanddims");
}
struct TestParams {
std::vector<int> input_dims;
int axis;
std::vector<int> expected_output_dims;
std::vector<TestParamBase> test_params = {
TestParamBase{{1, 1, 2, 3},
{},
{1, 1, 1, 2, 3},
{0},
trt_mode == TrtTestMode::kImplicitBatch
? Status(error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the "
"batch dimension, at my_expanddims")
: Status::OK()},
TestParamBase{{1, 1, 2, 3},
{},
{1, 1, 1, 2, 3},
{-5},
trt_mode == TrtTestMode::kImplicitBatch
? Status(error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the "
"batch dimension, at my_expanddims")
: Status::OK()},
TestParamBase{{1, 1, 2, 3},
{},
{},
{5},
Status(error::INVALID_ARGUMENT,
"Axis value of 5 is out of bounds, must be in range"
" [-5, 5), at my_expanddims")},
TestParamBase{{1, 1, 2, 3},
{},
{},
{-6},
Status(error::INVALID_ARGUMENT,
"Axis value of -6 is out of bounds, must be in range"
" [-5, 5), at my_expanddims")},
TestParamBase{{1, 2, 3}, {}, {1, 1, 2, 3}, {1}},
TestParamBase{{1, 2, 3}, {}, {1, 1, 2, 3}, {-3}},
TestParamBase{{1, 2, 3}, {}, {1, 2, 3, 1}, {3}},
TestParamBase{{1, 2, 3}, {}, {1, 2, 3, 1}, {-1}},
TestParamBase{{1, 2, 3}, {}, {1, 2, 1, 3}, {2}},
TestParamBase{{1, 2, 3}, {}, {1, 2, 1, 3}, {-2}},
TestParamBase{{1, 6}, {}, {1, 1, 6}, {1}},
TestParamBase{{1, 6}, {}, {1, 6, 1}, {-1}},
};
// Ok.
std::vector<TestParams> ok_params = {
TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}},
TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}},
TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}},
TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}},
};
for (int i = 0; i < ok_params.size(); ++i) {
for (auto p : test_params) {
Reset();
AddTestTensor("input", ok_params[i].input_dims);
AddTestWeights<int32>("weights", {1}, {ok_params[i].axis});
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
const DataVec input_data{{"input", AsTensor<float>({1, 2, 3, 4, 5, 6})}};
DataVec output_data{{"my_expanddims", ConstructTensor<float>(6)}};
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(1, 2, 3, 4, 5, 6));
AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6});
AddTestWeights<int32>("weights", {1}, {p.param[0]});
TestOpConverter("my_expanddims", node_def, p.expected_output_dims, p.status,
p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
}
@ -3176,9 +3092,9 @@ TEST_P(OpConverterTest1, ConvertSqueeze) {
const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch);
// Get the NodeDef for Squeeze.
auto get_squeeze_nodedef = [](std::vector<int> axes,
DataType tf_dtype) -> NodeDef {
DataType tf_type) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
if (!axes.empty()) {
ops::Squeeze::Attrs squeeze_attrs;
squeeze_attrs.axis_ = gtl::ArraySlice<int>(axes); // non-absl ok
@ -3270,7 +3186,7 @@ TEST_P(OpConverterTest1, ConvertSqueeze) {
for (TestParamBase p : test_params) {
SCOPED_TRACE(p);
Reset();
NodeDef node_def = get_squeeze_nodedef(p.param, tf_dtype);
NodeDef node_def = get_squeeze_nodedef(p.param, tf_type);
AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6},
p.partial_input_dims);
TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status,
@ -4025,14 +3941,14 @@ TEST_F(OpConverterTest, ConvertSlice) {
TEST_P(OpConverterTest1, ConvertConv2D) {
// Get nodedef for Conv2D layer.
DataType tf_type = tf_dtype;
DataType tf_type_loc = tf_type;
auto get_conv2d_nodedef =
[tf_type](std::vector<int> strides = {1, 1, 1, 1},
string padding = "SAME", string data_format = "NCHW",
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
[tf_type_loc](std::vector<int> strides = {1, 1, 1, 1},
string padding = "SAME", string data_format = "NCHW",
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type_loc);
auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type_loc);
ops::Conv2D::Attrs attrs =
ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
@ -4233,8 +4149,8 @@ TEST_P(OpConverterTest1, ConvertConv2D) {
partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
}
AddTestTensor("input", ok_params[i].input_dims, tf_dtype,
ok_params[i].input, partial_input_shape);
AddTestTensor("input", ok_params[i].input_dims, tf_type, ok_params[i].input,
partial_input_shape);
AddTestWeights<float>("weights", ok_params[i].filter_dims,
ok_params[i].filter);
@ -4938,17 +4854,34 @@ TEST_F(OpConverterTest, ConvertTopK) {
}
}
template <DataType dtype>
void TestConvertGather(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
TEST_P(OpConverterTest3, ConvertGather) {
// Get the NodeDef for GatherV2.
Scope s = Scope::NewRootScope();
auto params = ops::Placeholder(s.WithOpName("params"), dtype);
auto params = ops::Placeholder(s.WithOpName("params"), tf_type);
auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
const NodeDef& node_def = gather.operation.node()->def();
{
// Axis is a tensor, should fail.
Reset();
AddTestTensor("params", {1, 1, 2, 3}, tf_type, {});
AddTestTensor("indices", {1, 2}, DT_INT32, {});
AddTestTensor("axis", {1}, DT_INT32, {});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"The input \"axis\" for GatherV2 must be a constant, at my_gather");
}
{
// Axis is out of bounds, should fail.
Reset();
AddTestTensor("params", {1, 1, 2, 3});
AddTestTensor("indices", {1, 2}, DT_INT32, {});
AddTestWeights<int32>("axis", {1}, {4});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Axis value of 4 is out of bounds, must be in "
"range [-4, 4), at my_gather");
}
struct TestParams {
// TF shape of the input 'params' (including batch dimension).
@ -4961,12 +4894,74 @@ void TestConvertGather(OpConverterTest* test) {
std::vector<int> expected_output_shape;
std::vector<int> expected_output;
bool params_is_tensor;
Status status;
Status runtime_status;
Status add_index_status;
};
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
CType(4), CType(5), CType(6)};
std::vector<TestParams> ok_params = {
const std::vector<int> params_input = {1, 2, 3, 4, 5, 6};
std::vector<TestParams> test_params = {
// Axis is batch dimension, should fail in implicit batch mode.
TestParams{/*params_shape=*/{2, 1, 1, 3},
/*indices_shape=*/{2},
/*indices=*/{1, 0},
/*axis=*/0,
/*expected_output_shape=*/{2, 1, 1, 3},
/*expected_output=*/{4, 5, 6, 1, 2, 3},
/*params_is_tensor=*/true,
trt_mode == TrtTestMode::kImplicitBatch
? Status{error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the"
" batch dimension, at my_gather"}
: Status::OK()},
// Batch size of indices is not 1 when params is a tensor.
TestParams{/*params_shape=*/{2, 1, 3},
/*indices_shape=*/{2, 1},
/*indices=*/{2, 0},
/*axis=*/2,
/*expected_output_shape=*/{2, 1, 2, 1},
/*expected_output=*/{3, 1, 6, 4},
/*params_is_tensor=*/true,
trt_mode == TrtTestMode::kImplicitBatch
? Status{error::UNIMPLEMENTED,
"Indices must have a batch size of 1 when params"
" is a tensor."}
: Status::OK()},
// Axis is not zero when params is a weight, should fail in implicit batch
// mode.
TestParams{/*params_shape=*/{2, 1, 3},
/*indices_shape=*/{2},
/*indices=*/{1, 2},
/*axis=*/2,
/*expected_output_shape=*/{2, 1, 2},
/*expected_output=*/{2, 3, 5, 6},
/*params_is_tensor=*/false,
trt_mode == TrtTestMode::kImplicitBatch
? Status{error::UNIMPLEMENTED,
"The input axis must be zero when params is a"
" weight."}
: Status::OK()},
// Params with only batch dimension.
TestParams{/*params_shape=*/{6},
/*indices_shape=*/{2},
/*indices=*/{1, 3},
/*axis=*/0,
/*expected_output_shape=*/{2},
/*expected_output=*/{2, 4},
/*params_is_tensor=*/true,
trt_mode == TrtTestMode::kImplicitBatch // conversion_status
? Status{error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the "
"batch dimension, at my_gather"}
: Status::OK(),
Status::OK(), // runtime_status
trt_mode == TrtTestMode::kImplicitBatch // add_index_status
? Status{error::INVALID_ARGUMENT,
"Batch size doesn't match for tensor indices: "
"Provided batch size does not match converter "
"batch size: 2 vs 6"}
: Status::OK()},
// Vector indices, and output rank is rank(params).
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@ -4986,7 +4981,8 @@ void TestConvertGather(OpConverterTest* test) {
/*expected_output=*/{4, 5, 6},
/*params_is_tensor=*/true,
},
// Indices with rank>1, and output rank is rank(params)+rank(indices)-1.
// Indices with rank>1, and output rank is rank(params) + rank(indices) -
// 1
TestParams{
/*params_shape=*/{1, 1, 2, 3},
/*indices_shape=*/{1, 1},
@ -5070,125 +5066,22 @@ void TestConvertGather(OpConverterTest* test) {
},
};
// Ok.
for (int i = 0; i < ok_params.size(); i++) {
test->Reset();
const auto& params_shape = ok_params[i].params_shape;
if (ok_params[i].params_is_tensor) {
std::vector<int> params_dims(params_shape.begin() + 1,
params_shape.end());
test->AddTestTensor("params", params_dims, params_shape[0],
TfDataTypeToTrt(dtype));
for (auto p : test_params) {
Reset();
if (p.params_is_tensor) {
AddTestTensor("params", p.params_shape, params_input);
} else {
test->AddTestWeights<CType>("params", params_shape, params_input);
AddTestWeights("params", p.params_shape, params_input, tf_type);
}
const auto& indices_shape = ok_params[i].indices_shape;
test->AddTestTensor(
"indices",
std::vector<int>(indices_shape.begin() + 1, indices_shape.end()),
indices_shape[0], nvinfer1::DataType::kINT32);
test->AddTestWeights<int32>("axis", {1}, {ok_params[i].axis});
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output));
ASSERT_TRUE(output.is_tensor());
const auto& expected_output_shape = ok_params[i].expected_output_shape;
const auto& expected_output = ok_params[i].expected_output;
ASSERT_EQ(expected_output.size(),
TrtWeightDimsNumElements(GetTestDims(expected_output_shape)));
const std::vector<int> expected_output_dims(
expected_output_shape.begin() + 1, expected_output_shape.end());
ExpectTrtDimsEqualsArray(expected_output_dims,
output.tensor()->getDimensions());
// Create input in CType and convert expected output to CType.
std::vector<CType> converted_expected_output(expected_output.begin(),
expected_output.end());
DataVec input_data;
if (ok_params[i].params_is_tensor) {
input_data = {{"params", test->AsTensor<CType>(params_input)},
{"indices", test->AsTensor<int32>(ok_params[i].indices)}};
} else {
input_data = {{"indices", test->AsTensor<int32>(ok_params[i].indices)}};
}
DataVec output_data{
{"my_gather", test->ConstructTensor<CType>(expected_output.size())}};
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data,
/*batch_size=*/expected_output_shape[0]));
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(converted_expected_output));
AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
p.add_index_status);
AddTestWeights<int32>("axis", {1}, {p.axis});
TestOpConverter("my_gather", node_def, p.expected_output_shape, p.status,
p.runtime_status, ElementsAreArray(p.expected_output));
}
}
TEST_F(OpConverterTest, ConvertGather) {
// Get the NodeDef for GatherV2.
Scope s = Scope::NewRootScope();
auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT);
auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
const NodeDef& node_def = gather.operation.node()->def();
{
// Axis is a tensor, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestTensor("axis", {1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"The input \"axis\" for GatherV2 must be a constant, at my_gather");
}
{
// Axis is out of bounds, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestWeights<int32>("axis", {1}, {4});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Axis value of 4 is out of bounds, must be in "
"range [-4, 4), at my_gather");
}
{
// Axis is batch dimension, should fail.
Reset();
AddTestTensor("params", {1, 2, 3});
AddTestTensor("indices", {2});
AddTestWeights<int32>("axis", {1}, {0});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"TensorRT does not allow manipulation of the "
"batch dimension, at my_gather");
}
{
// Axis is not zero when params is a weight, should fail.
Reset();
AddTestWeights<int32>("params", {1, 3}, {1, 2, 3});
AddTestTensor("indices", {2});
AddTestWeights<int32>("axis", {1}, {1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"The input axis must be zero when params is a weight.");
}
{
// Batch size of indices is not 1 when params is a tensor.
Reset();
AddTestTensor("params", {1, 2, 3}, /*batch_size=*/2);
AddTestTensor("indices", {2}, /*batch_size=*/2);
AddTestWeights<int32>("axis", {1}, {1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Indices must have a batch size of 1 when params is a tensor.");
}
Reset();
TestConvertGather<DT_FLOAT>(this);
TestConvertGather<DT_HALF>(this);
TestConvertGather<DT_INT32>(this);
}
NodeDef CreateCastOp(DataType tf_dtype) {
NodeDef CreateCastOp(DataType tf_type) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF);
return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT)
@ -5200,7 +5093,7 @@ TEST_P(OpConverterTest1, ConvertUnary) {
{
// Input is weights, should fail.
Reset();
const NodeDef node_def = CreateUnaryOp<ops::Neg>(tf_dtype);
const NodeDef node_def = CreateUnaryOp<ops::Neg>(tf_type);
AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
@ -5256,7 +5149,7 @@ TEST_P(OpConverterTest1, ConvertUnary) {
if (!op_map.count(op_name)) {
FAIL() << "Unary op test map does not contain op " << op_name;
}
NodeDef node_def = op_map[op_name].first(tf_dtype);
NodeDef node_def = op_map[op_name].first(tf_type);
// TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
// now. Need to find a better way to express input and output types.
@ -5264,10 +5157,10 @@ TEST_P(OpConverterTest1, ConvertUnary) {
// TODO(tfeher): improve tests by defining an expected output data type and
// check that. Currently only the shape and values of the output are
// checked.
DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype;
DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type;
std::vector<float> input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
AddTestTensor("input", p.input_dims, input_tf_dtype, input_values);
AddTestTensor("input", p.input_dims, input_tf_type, input_values);
std::vector<float> output;
std::transform(input_values.begin(), input_values.end(),
std::back_inserter(output), op_map[op_name].second);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
@ -88,6 +89,10 @@ inline bool HasStaticShape(const nvinfer1::Dims& dims) {
return true;
}
inline bool HasStaticShape(std::vector<int> dims) {
return !absl::c_any_of(dims, [](int i) { return i < 0; });
}
template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
bool ignore_first_dim) {

View File

@ -182,7 +182,7 @@ respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.
The three gradients are defined by the following formulas (assuming a
4-dimensional array as `operand` and with feature dimension index $$l$$, batch
4-dimensional array as `operand` and with feature dimension index `l`, batch
size `m` and spatial sizes `w` and `h`):
\\[ \begin{split} c_l&=
@ -618,36 +618,44 @@ See also
<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
Arguments | Type | Semantics
------------------- | ---------------- | --------------------------------------
------------------- | ---------------- | ------------------------------------
`pred` | `XlaOp` | Scalar of type `PRED`
`true_operand` | `XlaOp` | Argument of type $$ T_0 $$
`true_computation` | `XlaComputation` | XlaComputation of type $$ T_0 \to S$$
`false_operand` | `XlaOp` | Argument of type $$ T_1 $$
`false_computation` | `XlaComputation` | XlaComputation of type $$ T_1 \to S $$
`true_operand` | `XlaOp` | Argument of type \\(T_0\\)
`true_computation` | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
`false_operand` | `XlaOp` | Argument of type \\(T_1\\)
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)
Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.
The `true_computation` must take in a single argument of type $$ T_0 $$ and will
The `true_computation` must take in a single argument of type \\(T_0\\) and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type $$ T_1 $$ and will be
`false_computation` must take in a single argument of type \\(T_1\\) and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.
<!-- mdformat on -->
Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.
<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
| Arguments | Type | Semantics |
| --------------------- | --------------------- | ---------------------------- |
| `branch_index` | `XlaOp` | Scalar of type `S32` |
| `branch_computations` | sequence of N | XlaComputations of type $$ |
| `branch_computations` | sequence of N | XlaComputations of type \\( |
: : `XlaComputation` : T_0 \to S , T_1 \to S , ..., :
: : : T_{N-1} \to S $$ :
| `branch_operands` | sequence of N `XlaOp` | Arguments of type $$ T_0 , |
: : : T_1 , ..., T_{N-1} $$ :
: : : T_{N-1} \to S \\) :
| `branch_operands` | sequence of N `XlaOp` | Arguments of type \\( T_0 , |
: : : T_1 , ..., T_{N-1} \\) :
<!-- mdformat on -->
Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
@ -803,15 +811,15 @@ Here is pseudo-code for a 2d convolution with padding and striding:
```
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
```
@ -891,19 +899,19 @@ Here is an example of an implementation of `myfunc`:
```
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
```
@ -1686,11 +1694,11 @@ dependency between the while loops.
```
result1 = while (condition, init = init_value) {
Infeed(shape)
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
Infeed(shape)
}
```

View File

@ -158,6 +158,20 @@ cc_library(
],
)
cc_library(
name = "interpreter_device",
srcs = ["interpreter_device.cc"],
hdrs = ["interpreter_device.h"],
deps = [
":pjrt_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:platform_util",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "cpu_device",
srcs = ["cpu_device.cc"],

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {
static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: Device(id, std::move(local_device_state), kInterpreterPlatformName,
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform("Interpreter"));
if (platform->VisibleDeviceCount() != 1) {
return FailedPrecondition(
"Interpreter platform should have exactly one device.");
}
LocalClientOptions options;
options.set_platform(platform);
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<Device>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/false,
/*allow_event_reuse=*/false);
auto device =
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
return std::make_shared<PjRtClient>(
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*gpu_run_options=*/nullptr);
}
} // namespace xla

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
#include <memory>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
class InterpreterDevice : public Device {
public:
InterpreterDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state);
};
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_

View File

@ -187,6 +187,7 @@ PjRtClient::PjRtClient(
CHECK(local_devices_[idx] == nullptr) << idx;
local_devices_[idx] = device.get();
}
device->client_ = this;
}
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;

View File

@ -47,6 +47,8 @@ limitations under the License.
namespace xla {
class PjRtClient;
class Device {
public:
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
@ -86,12 +88,17 @@ class Device {
virtual std::string DebugString() const;
PjRtClient* client() const { return client_; }
private:
friend class PjRtClient;
const int id_;
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string platform_name_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
// Forward declaration.
@ -113,7 +120,7 @@ using PjRtCrossHostRecvNotifier =
//
// It is the responsibility of the client of this API to keep the PjRtClient
// alive as long as any of the other runtime objects are alive.
class PjRtClient : public std::enable_shared_from_this<PjRtClient> {
class PjRtClient {
public:
// `allocator` may null, in which case the platform default allocator is used.
explicit PjRtClient(

View File

@ -123,6 +123,25 @@ cc_library(
],
)
cc_library(
name = "traceback",
srcs = ["traceback.cc"],
hdrs = ["traceback.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":python_ref_manager",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@pybind11",
],
)
cc_library(
name = "bfloat16",
srcs = ["bfloat16.cc"],
@ -159,6 +178,37 @@ py_test(
] + xla_py_test_deps(),
)
cc_library(
name = "py_client",
srcs = [
"py_buffer.cc",
"py_client.cc",
"py_executable.cc",
],
hdrs = [
"py_buffer.h",
"py_client.h",
"py_executable.h",
],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":python_ref_manager",
":traceback",
":types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",
],
)
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
@ -169,6 +219,8 @@ cc_library(
],
features = ["-use_header_modules"],
deps = [
":py_client",
":traceback",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
@ -233,7 +285,7 @@ cc_library(
)
tf_cc_test(
name = "cpu_outfeed_receiver_test",
name = "outfeed_receiver_test_cpu",
size = "small",
srcs = ["outfeed_receiver_test.cc"],
deps = [
@ -261,9 +313,11 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":outfeed_receiver",
":py_client",
":types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@pybind11",
@ -290,8 +344,10 @@ pybind_extension(
":bfloat16",
":dlpack",
":ops",
":py_client",
":python_ref_manager",
":outfeed_receiver_py",
":traceback",
":types",
"@com_google_absl//absl/base",
"@com_google_absl//absl/hash",
@ -307,6 +363,7 @@ pybind_extension(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
@ -315,6 +372,7 @@ pybind_extension(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:interpreter_device",
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",

View File

@ -23,7 +23,9 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
@ -239,12 +241,12 @@ StatusOr<Device*> DeviceForDLContext(const PjRtClient& client,
} // namespace
StatusOr<py::capsule> BufferToDLPackManagedTensor(PjRtBuffer* buffer) {
auto pack = absl::make_unique<DLPackTensor>();
StatusOr<py::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer) {
auto pack = std::make_unique<DLPackTensor>();
// Block on outstanding operations, so that it is safe to read or mutate the
// returned buffer.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or =
buffer->Release(/*wait_for_operations_to_complete=*/true);
buffer->buffer()->Release(/*wait_for_operations_to_complete=*/true);
if (!buffer_or.ok()) {
return InvalidArgument(
"Buffer synchronization failed converting to DLPack tensor: %s",
@ -258,22 +260,25 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PjRtBuffer* buffer) {
pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter;
DLTensor& dt = pack->tensor.dl_tensor;
if (buffer->on_device_shape().IsTuple()) {
if (buffer->buffer()->on_device_shape().IsTuple()) {
return Unimplemented(
"unsafe_buffer_pointer is not implemented for tuple "
"buffers.");
}
TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
dt.data = pack->buffer->device_memory().front().opaque();
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
dt.ndim = buffer->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
buffer->on_host_shape().element_type()));
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
dt.ctx.device_id =
buffer->buffer()->device()->local_device_state()->device_ordinal();
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype,
PrimitiveTypeToDLDataType(
buffer->buffer()->on_host_shape().element_type()));
pack->shape = std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
buffer->on_host_shape().dimensions().end());
pack->strides = StridesForShape(buffer->on_host_shape());
pack->shape =
std::vector<int64>(buffer->buffer()->on_host_shape().dimensions().begin(),
buffer->buffer()->on_host_shape().dimensions().end());
pack->strides = StridesForShape(buffer->buffer()->on_host_shape());
dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
dt.byte_offset = 0;
@ -293,8 +298,8 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(PjRtBuffer* buffer) {
return capsule;
}
StatusOr<std::unique_ptr<PjRtBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, PjRtClient* client) {
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
return InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
@ -307,8 +312,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> DLPackManagedTensorToBuffer(
"Number of dimensions in DLManagedTensor must be nonnegative, got %d",
dlmt->dl_tensor.ndim);
}
TF_ASSIGN_OR_RETURN(Device * device,
DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
TF_ASSIGN_OR_RETURN(
Device * device,
DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
absl::Span<int64 const> dimensions(
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
@ -344,8 +350,10 @@ StatusOr<std::unique_ptr<PjRtBuffer>> DLPackManagedTensorToBuffer(
// capsule it cannot be used again.
PyCapsule_SetName(tensor.ptr(), "used_dltensor");
PyCapsule_SetDestructor(tensor.ptr(), nullptr);
return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
client, device);
auto pjrt_buffer = std::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client->pjrt_client(), device);
return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
Traceback::Get());
}
} // namespace xla

View File

@ -17,14 +17,15 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
namespace xla {
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PjRtBuffer* buffer);
StatusOr<pybind11::capsule> BufferToDLPackManagedTensor(PyBuffer* buffer);
StatusOr<std::unique_ptr<PjRtBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, PjRtClient* client);
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
const pybind11::capsule& tensor, std::shared_ptr<PyClient> client);
} // namespace xla

View File

@ -98,23 +98,17 @@ uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;
// A Device and its PjRtClient.
struct DeviceWithClient {
Device* device;
std::shared_ptr<PjRtClient> client;
};
// Encapsulates data received from a device outfeed.
class OutfeedData {
public:
OutfeedData(DeviceWithClient device_client, uint32_t consumer_id, Shape shape)
: device_client_(device_client),
OutfeedData(Device* device, uint32_t consumer_id, Shape shape)
: device_(device),
consumer_id_(consumer_id),
shape_(shape),
literal_(nullptr),
literal_size_bytes_(0) {}
DeviceWithClient device_client() { return device_client_; }
Device* device() { return device_; }
uint32_t consumer_id() const { return consumer_id_; }
Shape shape() const { return shape_; }
std::unique_ptr<Literal> literal() {
@ -129,7 +123,7 @@ class OutfeedData {
std::string DebugString() const;
private:
DeviceWithClient device_client_;
Device* device_;
uint32_t consumer_id_;
Shape shape_;
std::unique_ptr<Literal> literal_;
@ -150,15 +144,14 @@ void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
}
std::string OutfeedData::DebugString() const {
return absl::StrFormat("dev=%s; cons=%d; shape=%s",
device_client_.device->DebugString(), consumer_id_,
shape_.ToString());
return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
consumer_id_, shape_.ToString());
}
class OutfeedReceiverImpl {
public:
OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes);
OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
@ -206,8 +199,8 @@ class OutfeedReceiverImpl {
void Shutdown();
OutfeedReceiver::Callback callback_;
// The devices on which we are listening, with their clients.
std::vector<DeviceWithClient> devices_;
// The devices on which we are listening.
std::vector<Device*> devices_;
// Maximum bytes capacity of the callback queue.
uint64_t max_callback_queue_size_bytes_;
@ -232,14 +225,13 @@ class OutfeedReceiverImpl {
};
OutfeedReceiverImpl::OutfeedReceiverImpl(
OutfeedReceiver::Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes) {
callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) {
for (const auto& device : client->devices()) {
devices_.push_back(DeviceWithClient{device.get(), client});
devices_.push_back(device.get());
}
}
CHECK_GT(devices_.size(), 0);
@ -291,11 +283,11 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
absl::MutexLock lock(&mu_);
++num_listening_threads_;
}
DeviceWithClient device_client = devices_[device_idx];
Device* device = devices_[device_idx];
while (true) {
Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
std::unique_ptr<Literal> header =
ReceiveRawFromOutfeed(device_client.device, header_shape).ValueOrDie();
ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
absl::Span<uint32_t> header_data = header->data<uint32>();
CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
CHECK_EQ(header_data[0], kOutfeedHeaderStart);
@ -306,18 +298,17 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
auto registered_shape = shape_registry_.find(consumer_id);
if (registered_shape == shape_registry_.end()) {
LOG(FATAL)
<< "[" << device_client.device->DebugString()
<< "[" << device->DebugString()
<< "] Cannot find registered shape for consumer ID " << consumer_id
<< ". Perhaps the code was compiled with a different instance "
<< "of OutfeedReceiver.";
}
shape = registered_shape->second;
}
auto received =
absl::make_unique<OutfeedData>(device_client, consumer_id, shape);
auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
VLOG(2) << "Listener received header " << received->DebugString();
if (consumer_id == kOutfeedCidShutdown) {
VLOG(2) << "[" << device_client.device->DebugString()
VLOG(2) << "[" << device->DebugString()
<< "] Listener received shutdown header";
absl::MutexLock lock(&mu_);
--num_listening_threads_;
@ -328,7 +319,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
return;
}
std::unique_ptr<Literal> data =
ReceiveRawFromOutfeed(device_client.device, shape).ValueOrDie();
ReceiveRawFromOutfeed(device, shape).ValueOrDie();
received->SetLiteral(std::move(data));
absl::MutexLock lock(&mu_);
EnqueueReceivedData(std::move(received));
@ -392,15 +383,14 @@ void OutfeedReceiverImpl::CallbackThreadLoop() {
}
{
tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
DeviceWithClient device_client = received->device_client();
callback_(device_client.device, std::move(device_client.client),
received->consumer_id(), received->literal());
callback_(received->device(), received->consumer_id(),
received->literal());
}
}
}
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
const Device* device = devices_[device_idx].device;
const Device* device = devices_[device_idx];
constexpr int consumer_id = kOutfeedCidShutdown;
VLOG(2) << "[" << device->DebugString()
<< "] SendSpecialHeader cons=" << consumer_id;
@ -421,7 +411,7 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, devices_[device_idx].client.get(),
PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
std::move(compile_options)));
ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
@ -468,11 +458,11 @@ StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
return token;
}
OutfeedReceiver::OutfeedReceiver(
Callback callback, std::vector<std::shared_ptr<PjRtClient>> clients,
ssize_t max_callback_queue_size_bytes) {
OutfeedReceiver::OutfeedReceiver(Callback callback,
absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes) {
p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
callback, std::move(clients), max_callback_queue_size_bytes);
callback, clients, max_callback_queue_size_bytes);
}
OutfeedReceiver::~OutfeedReceiver() {}

View File

@ -31,10 +31,9 @@ class OutfeedReceiverImpl;
// Implements a multithreaded receiver of outfeeds from devices.
class OutfeedReceiver {
public:
// A callback takes: device, client (for the device), consumer id, received.
// The client pointer should be alive while the device is used.
using Callback = std::function<void(Device*, std::shared_ptr<PjRtClient>,
uint32_t, std::shared_ptr<Literal>)>;
// A callback takes: device, consumer id, received.
using Callback =
std::function<void(Device*, uint32_t, std::shared_ptr<Literal>)>;
// Constructs the receiver for the given clients and callback function.
//
@ -45,8 +44,7 @@ class OutfeedReceiver {
// max_callback_queue_size_bytes: the maximum number of bytes for all
// received outfeeds queued to be processed. When this limit is reached
// we pause receiving outfeeds from devices.
OutfeedReceiver(Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes);
OutfeedReceiver(const OutfeedReceiver&) = delete;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "pybind11/functional.h"
@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
@ -41,17 +43,22 @@ class OutfeedReceiverForPython {
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
OutfeedReceiverForPython(CallbackToPython callback_python,
std::vector<std::shared_ptr<PjRtClient>> clients,
ssize_t max_callback_queue_size_bytes) {
callback_python_ = callback_python;
outfeed_receiver_shutting_down_ = false;
std::vector<std::shared_ptr<PyClient>> clients,
ssize_t max_callback_queue_size_bytes)
: callback_python_(std::move(callback_python)),
clients_(std::move(clients)) {
OutfeedReceiver::Callback callback =
[this](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
this->Callback(device, client, consumer_id, literal);
[this](Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
this->Callback(device, consumer_id, std::move(literal));
};
std::vector<PjRtClient*> client_ptrs(clients.size());
absl::c_transform(clients_, client_ptrs.begin(),
[](const std::shared_ptr<PyClient>& client) {
return client->pjrt_client();
});
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
callback, std::move(clients), max_callback_queue_size_bytes);
callback, client_ptrs, max_callback_queue_size_bytes);
}
OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
@ -79,8 +86,8 @@ class OutfeedReceiverForPython {
arrays);
}
void Callback(Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
void Callback(Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
{
absl::MutexLock lock(&mu_);
if (outfeed_receiver_shutting_down_) {
@ -88,19 +95,26 @@ class OutfeedReceiverForPython {
return;
}
}
// We expect the number of clients to be small, so an O(n) search is fine.
auto it = absl::c_find_if(
clients_, [device](const std::shared_ptr<PyClient>& client) {
return client->pjrt_client() == device->client();
});
CHECK(it != clients_.end());
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
py::object literal_python =
LiteralToPython(std::move(literal)).ValueOrDie();
// The callback_ should handle all exceptions in user-code. If we get
// an exception here, it is a bug in the callback and we should stop.
callback_python_(WrapWithClient<Device>(std::move(client), device),
consumer_id, std::move(literal_python));
callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
std::move(literal_python));
}
private:
CallbackToPython callback_python_;
absl::Mutex mu_;
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_);
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
std::vector<std::shared_ptr<PyClient>> clients_;
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
};
@ -112,7 +126,7 @@ void BuildOutfeedReceiverSubmodule(py::module* m) {
outfeed_receiver.def(
"start",
[](OutfeedReceiverForPython::CallbackToPython callback_to_python,
std::vector<std::shared_ptr<PjRtClient>> clients,
std::vector<std::shared_ptr<PyClient>> clients,
ssize_t max_callback_queue_size_bytes)
-> std::unique_ptr<OutfeedReceiverForPython> {
auto server = absl::make_unique<OutfeedReceiverForPython>(

View File

@ -75,14 +75,14 @@ class Accumulator {
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -108,14 +108,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -153,14 +153,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -196,14 +196,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -230,14 +230,14 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();

View File

@ -0,0 +1,218 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
namespace py = pybind11;
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtBuffer> buffer,
std::unique_ptr<Traceback> traceback)
: client_(std::move(client)),
buffer_(std::move(buffer)),
traceback_(std::move(traceback)) {}
ClientAndPtr<Device> PyBuffer::device() const {
return WrapWithClient(client_, buffer_->device());
}
StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
const ClientAndPtr<Device>& dst_device) const {
CHECK(dst_device.get() != nullptr);
GlobalPyRefManager()->CollectGarbage();
auto traceback = Traceback::Get();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> out,
buffer_->CopyToDevice(dst_device.get()));
return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
std::move(traceback));
}
Status PyBuffer::BlockHostUntilReady() {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return buffer_->BlockHostUntilReady();
}
StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer_->AsShapedBuffer());
if (shaped_buffer.on_device_shape().IsTuple()) {
return Unimplemented(
"unsafe_buffer_pointer is not implemented for tuple "
"buffers.");
}
return absl::bit_cast<std::uintptr_t>(shaped_buffer.root_buffer().opaque());
}
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
se::PlatformKind::kCuda) {
return InvalidArgument(
"__cuda_array_interface__ is only defined for NVidia GPU buffers.");
}
if (!buffer_->on_device_shape().IsArray()) {
return InvalidArgument(
"__cuda_array_interface__ is only defined for array buffers.");
}
if (buffer_->on_host_shape().element_type() == BF16) {
return InvalidArgument(
"__cuda_array_interface__ is not supported for bfloat16 buffers.");
}
TF_RET_CHECK(
LayoutUtil::IsMonotonicWithDim0Major(buffer_->on_host_shape().layout()));
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer_->AsShapedBuffer());
py::dict result;
result["shape"] = IntSpanToTuple(shaped_buffer.on_host_shape().dimensions());
TF_ASSIGN_OR_RETURN(py::str typestr,
TypeDescriptorForPrimitiveType(
shaped_buffer.on_host_shape().element_type()));
result["typestr"] = std::move(typestr);
py::tuple data(2);
data[0] = py::int_(
absl::bit_cast<std::uintptr_t>(shaped_buffer.root_buffer().opaque()));
data[1] = py::bool_(true); // read-only
result["data"] = std::move(data);
result["version"] = py::int_(2);
return result;
}
// PEP 3118 buffer protocol implementation.
namespace {
// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
explicit ExtraBufferInfo(PjRtBuffer::ScopedHold device_buffer)
: device_buffer(std::move(device_buffer)) {}
std::string format;
std::vector<Py_ssize_t> strides;
// We keep a reference to the TrackedDeviceBuffer that backs the
// PjRtBuffer. This prevents a use-after-free in the event that Delete() is
// called on a buffer with an live buffer protocol view. It does however mean
// that Delete() sometimes won't actually delete immediately.
PjRtBuffer::ScopedHold device_buffer;
};
int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
auto& buffer =
*py::reinterpret_borrow<py::object>(exporter).cast<PyBuffer&>().buffer();
Status status = [&]() {
// Py_buffer objects are POD C structures, so we don't need to hold the GIL.
// Additionally we call BlockHostUntilReady() below, which may block.
py::gil_scoped_release gil_release;
if (buffer.device()->platform_name() != "cpu") {
return InvalidArgument(
"Python buffer protocol is only defined for CPU buffers.");
}
if (!buffer.on_device_shape().IsArray()) {
return InvalidArgument(
"Python buffer protocol is only defined for array buffers.");
}
// If we allowed exports of formatted BF16 buffers, consumers would get
// confused about the type because there is no way to describe BF16 to
// Python.
if (buffer.on_host_shape().element_type() == BF16 &&
((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
return InvalidArgument(
"bfloat16 buffer format not supported by Python buffer protocol.");
}
if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
return InvalidArgument("XLA buffers are read-only.");
}
PjRtBuffer::ScopedHold device_buffer(
buffer.GetBufferWithExternalReference());
if (!device_buffer.status().ok()) {
return InvalidArgument("Deleted buffer used in buffer protocol.");
}
const Shape& shape = buffer.on_host_shape();
if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
(flags & PyBUF_STRIDES) == PyBUF_ND) &&
!LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
return InvalidArgument("Buffer is not in C-contiguous layout.");
} else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
!LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
return InvalidArgument("Buffer is not in F-contiguous layout.");
} else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
!LayoutUtil::IsMonotonicWithDim0Major(shape.layout()) &&
!LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
return InvalidArgument("Buffer is not in contiguous layout.");
}
std::memset(view, 0, sizeof(Py_buffer));
CHECK_EQ(device_buffer->device_memory().size(), 1);
view->buf =
const_cast<void*>(device_buffer->device_memory().front().opaque());
auto extra = absl::make_unique<ExtraBufferInfo>(std::move(device_buffer));
view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
view->len = ShapeUtil::ByteSizeOf(shape);
view->readonly = 1;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
shape.element_type()));
view->format = const_cast<char*>(extra->format.c_str());
}
if ((flags & PyBUF_ND) == PyBUF_ND) {
view->ndim = shape.dimensions_size();
static_assert(sizeof(int64) == sizeof(Py_ssize_t),
"Py_ssize_t must be 64 bits");
if (view->ndim != 0) {
view->shape = reinterpret_cast<Py_ssize_t*>(
const_cast<int64*>(shape.dimensions().data()));
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
extra->strides = ByteStridesForShape(shape);
view->strides = extra->strides.data();
}
}
}
TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
view->internal = extra.release();
return Status::OK();
}();
if (!status.ok()) {
PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
return -1;
}
view->obj = exporter;
Py_INCREF(view->obj);
return 0;
}
void PjRtBufferReleaseBuffer(PyObject*, Py_buffer* buffer) {
auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
delete extra;
}
PyBufferProcs PjRtBufferProcs = []() {
PyBufferProcs procs;
procs.bf_getbuffer = &PjRtBufferGetBuffer;
procs.bf_releasebuffer = &PjRtBufferReleaseBuffer;
return procs;
}();
} // namespace
/*static*/ PyBufferProcs* PyBuffer::BufferProtocol() {
return &PjRtBufferProcs;
}
} // namespace xla

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Python wrapper around PjRtBuffer. We use a wrapper class:
// a) to keep the PjRtClient alive via a std::shared_ptr<>
// b) to add Python-specific functionality.
class PyBuffer {
public:
PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
std::unique_ptr<Traceback> traceback);
std::shared_ptr<PyClient> client() const { return client_; }
PjRtBuffer* buffer() const { return buffer_.get(); }
ClientAndPtr<Device> device() const;
const std::string& platform_name() const { return buffer_->platform_name(); }
bool is_deleted() const { return buffer_->IsDeleted(); }
StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
const ClientAndPtr<Device>& dst_device) const;
void Delete() { return buffer_->Delete(); }
Status BlockHostUntilReady();
Status CopyToHostAsync() { return buffer_->CopyToHostAsync(); }
const Shape& shape() { return buffer_->on_host_shape(); }
StatusOr<std::uintptr_t> UnsafeBufferPointer() const;
// Implementation of the CUDA array interface for sharing GPU buffers with
// other Python libraries.
StatusOr<pybind11::dict> CudaArrayInterface() const;
// PEP 3118 Python buffer protocol implementation.
static PyBufferProcs* BufferProtocol();
Traceback* traceback() { return traceback_.get(); }
private:
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtBuffer> buffer_;
std::unique_ptr<Traceback> traceback_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_

View File

@ -0,0 +1,130 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
namespace py = pybind11;
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
: pjrt_client_(std::move(pjrt_client)) {}
std::vector<ClientAndPtr<Device>> PyClient::Devices() {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(pjrt_client_->devices().size());
for (const auto& device : pjrt_client_->devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
}
return devices;
}
std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(pjrt_client_->local_devices().size());
for (Device* device : pjrt_client_->local_devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device));
}
return devices;
}
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
std::vector<std::vector<ClientAndPtr<Device>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
}
}
return result;
}
StatusOr<std::vector<ClientAndPtr<Device>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<ClientAndPtr<Device>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result.push_back(WrapWithClient(shared_from_this(), iter->second));
}
return result;
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front();
}
CHECK(device != nullptr);
auto iter = pjrt_client_->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
device->DebugString(),
pjrt_client_->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
absl::optional<CastToArrayResult> c = CastToArray(argument);
if (!c) {
return InvalidArgument("from_python argument must be an array.");
}
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReference(std::move(c->array));
auto traceback = Traceback::Get();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> buffer,
PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref), pjrt_client_.get(),
device));
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
std::move(traceback));
}
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
const XlaComputation& computation, CompileOptions options) {
auto traceback = Traceback::Get();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, pjrt_client_.get(),
std::move(options)));
return std::make_unique<PyExecutable>(
shared_from_this(), std::move(executable), std::move(traceback));
}
} // namespace xla

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
class PyBuffer;
class PyClient;
class PyExecutable;
// Custom holder types.
//
// We must keep the PyClient object alive as long as any of the runtime
// objects are alive. Since we don't have a lot of control over Python
// destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
// and ensure that each Python runtime object holds a reference to the
// PyClient. An alternative design would be to keep a single global
// singleton PyClient, although this seems less flexible, especially for
// writing tests.
//
// To maintain PyClient references, we define pybind11 holder classes that
// are custom smart pointers that also keep a reference to a PyClient.
// pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
// seem sufficiently flexible to describe ownership relationships in cases where
// the ownership doesn't pertain to a direct argument or return value of a
// function. Another alternative to the holder classes would be to create proxy
// objects that contain both a reference and a runtime class; holder classes
// seem less tedious to define.
// A pair of a PyClient reference and an unowned pointer to T.
template <typename T>
struct ClientAndPtr {
ClientAndPtr() = default;
// pybind11 requires that we define a constructor that takes a raw pointer,
// but it should be unreachable.
explicit ClientAndPtr(T*) {
LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
}
ClientAndPtr(const ClientAndPtr&) = default;
ClientAndPtr(ClientAndPtr&&) = default;
ClientAndPtr& operator=(const ClientAndPtr&) = default;
ClientAndPtr& operator=(ClientAndPtr&&) = default;
std::shared_ptr<PyClient> client;
T* contents;
T* get() const { return contents; }
T* operator->() const { return contents; }
T& operator*() const { return *contents; }
};
// By defining a templated helper function, we can use return type deduction
// and avoid specifying types at the caller.
template <typename T>
ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
ClientAndPtr<T> result;
result.client = std::move(client);
result.contents = contents;
return result;
}
// Python wrapper around PjRtClient.
// We use a wrapper class to add Python-specific functionality.
class PyClient : public std::enable_shared_from_this<PyClient> {
public:
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
const std::string& platform_name() const {
return pjrt_client_->platform_name();
}
int local_device_count() const { return pjrt_client_->local_device_count(); }
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }
std::vector<ClientAndPtr<Device>> Devices();
std::vector<ClientAndPtr<Device>> LocalDevices();
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
// TODO(skye): delete after all callers can handle 2D output
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
int num_replicas);
StatusOr<ChannelHandle> CreateChannelHandle() {
return pjrt_client_->client()->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy);
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
private:
std::shared_ptr<PjRtClient> pjrt_client_;
};
} // namespace xla
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_

View File

@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "absl/algorithm/container.h"
namespace xla {
namespace py = pybind11;
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::unique_ptr<Traceback> traceback)
: client_(std::move(client)),
executable_(std::move(executable)),
traceback_(std::move(traceback)) {}
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
std::vector<ClientAndPtr<Device>> devices;
devices.reserve(executable_->local_devices().size());
for (Device* device : executable_->local_devices()) {
devices.push_back(WrapWithClient(client_, device));
}
return devices;
}
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
absl::Span<PyBuffer* const> args) {
auto traceback = Traceback::Get();
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
std::vector<PjRtBuffer*> arg_buffers(args.size());
absl::c_transform(args, arg_buffers.begin(),
[](PyBuffer* buf) { return buf->buffer(); });
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
executable_->Execute(arg_buffers, options));
std::vector<std::unique_ptr<PyBuffer>> outputs;
outputs.reserve(output_buffers.size());
for (auto& buffer : output_buffers) {
outputs.push_back(std::make_unique<PyBuffer>(client_, std::move(buffer),
std::move(traceback)));
}
return outputs;
}
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
PyExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PyBuffer*>> args) {
auto traceback = Traceback::Get();
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
for (int computation = 0; computation < args.size(); ++computation) {
arg_buffers[computation].resize(args[computation].size());
absl::c_transform(args[computation], arg_buffers[computation].begin(),
[](PyBuffer* buf) { return buf->buffer(); });
}
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
executable_->ExecuteOnLocalDevices(arg_buffers, options));
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
outputs.resize(output_buffers.size());
for (int computation = 0; computation < output_buffers.size();
++computation) {
for (auto& buffer : output_buffers[computation]) {
outputs[computation].push_back(std::make_unique<PyBuffer>(
client_, std::move(buffer), std::move(traceback)));
}
}
return outputs;
}
StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
const {
std::vector<std::shared_ptr<HloModule>> modules;
modules.reserve(executable_->executables().size());
for (const auto& local_exec : executable_->executables()) {
if (!local_exec->executable()->has_module()) {
return InvalidArgument("Executable does not have HLO modules.");
}
modules.push_back(local_exec->executable()->shared_module());
}
return std::move(modules);
}
} // namespace xla

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_
#include <memory>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Python wrapper around PjRtExecutable. We use a wrapper class:
// a) to keep the PyClient alive via a std::shared_ptr<>
// b) to add Python-specific functionality.
class PyExecutable {
public:
PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::unique_ptr<Traceback> traceback);
std::shared_ptr<PyClient> client() const { return client_; }
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
return executable_->local_logical_device_ids();
}
std::vector<ClientAndPtr<Device>> LocalDevices() const;
int64 SizeOfGeneratedCodeInBytes() const {
return executable_->SizeOfGeneratedCodeInBytes();
}
void Delete() { return executable_->Delete(); }
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> Execute(
absl::Span<PyBuffer* const> args);
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer*>> args);
StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const;
Traceback* traceback() { return traceback_.get(); }
private:
std::shared_ptr<PyClient> client_;
std::unique_ptr<PjRtExecutable> executable_;
std::unique_ptr<Traceback> traceback_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_

View File

@ -50,6 +50,22 @@ PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
return std::make_shared<ManagedPyObjects>(this, objects);
}
void PythonRefManager::AddGarbage(absl::Span<py::object> garbage) {
absl::MutexLock lock(&mu_);
for (py::object& o : garbage) {
python_garbage_.push_back(std::move(o));
}
}
void PythonRefManager::AddGarbage(
absl::Span<std::pair<PyCodeObject*, int> const> garbage) {
absl::MutexLock lock(&mu_);
for (const auto& o : garbage) {
python_garbage_.push_back(py::reinterpret_steal<py::object>(
reinterpret_cast<PyObject*>(o.first)));
}
}
void PythonRefManager::CollectGarbage() {
// TODO(phawkins): we should CHECK(PyGILState_Check());
std::deque<pybind11::object> garbage;

View File

@ -66,6 +66,10 @@ class PythonRefManager {
std::shared_ptr<ManagedPyObjects> ManageReferences(
absl::Span<pybind11::object> objects);
// Adds garbage objects to the manager.
void AddGarbage(absl::Span<pybind11::object> garbage);
void AddGarbage(absl::Span<std::pair<PyCodeObject*, int> const> garbage);
// Releases the contents of python_garbage_. Requires that the GIL is held.
// The client calls this method during API entry points where the GIL is held
// to free any garbage that has accumulated.

View File

@ -173,9 +173,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("shape", &PyTpuBuffer::on_host_shape)
.def("device", &PyTpuBuffer::device)
.def("platform", &PyTpuBuffer::platform_name)
.def("is_deleted", [](const PyTpuBuffer& buffer) {
return buffer.DeviceBuffer() == nullptr;
});
.def("is_deleted",
[](const PyTpuBuffer& buffer) {
return buffer.DeviceBuffer() == nullptr;
})
// TODO(phawkins): implement traceback support.
.def_property_readonly("traceback",
[](PyTpuBuffer*) { return py::none(); });
py::class_<PyTpuExecutable>(m, "TpuExecutable")
.def("local_logical_device_ids",
@ -193,7 +197,10 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("execute", &PyTpuExecutable::Execute,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
.def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices,
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
// TODO(phawkins): implement traceback support.
.def_property_readonly("traceback",
[](PyTpuExecutable*) { return py::none(); });
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)

Some files were not shown because too many files have changed in this diff Show More