diff --git a/.bazelrc b/.bazelrc index 396b84f70b3..03208385283 100644 --- a/.bazelrc +++ b/.bazelrc @@ -49,7 +49,6 @@ # rocm: Build with AMD GPU support (rocm). # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. -# ngraph: Enable ngraph support. # numa: Enable numa using hwloc. # noaws: Disable AWS S3 storage support # nogcp: Disable GCS support. @@ -159,6 +158,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl --define=build_with_openmp=true build:mkl -c opt # config to build OneDNN backend with a user specified threadpool. @@ -172,6 +172,7 @@ build:mkl_threadpool -c opt build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_opensource_only --define=build_with_mkl_opensource=true +build:mkl_opensource_only --define=build_with_openmp=true build:mkl_opensource_only -c opt # Config setting to build with oneDNN for Arm. @@ -218,7 +219,6 @@ build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm --action_env TF_NEED_ROCM=1 # Options extracted from configure script -build:ngraph --define=with_ngraph_support=true build:numa --define=with_numa_support=true # Options to disable default on features @@ -283,7 +283,7 @@ build:ios --copt=-w build:linux --copt=-w build:linux --host_copt=-w build:macos --copt=-w -build:windows --copt=/w +build:windows --copt=/W0 # Tensorflow uses M_* math constants that only get defined by MSVC headers if # _USE_MATH_DEFINES is defined. @@ -294,9 +294,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES build:linux --define=PREFIX=/usr build:linux --define=LIBDIR=$(PREFIX)/lib build:linux --define=INCLUDEDIR=$(PREFIX)/include +build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include build:macos --define=PREFIX=/usr build:macos --define=LIBDIR=$(PREFIX)/lib build:macos --define=INCLUDEDIR=$(PREFIX)/include +build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. diff --git a/README.md b/README.md index 31888cfbbc6..116477670f6 100644 --- a/README.md +++ b/README.md @@ -103,23 +103,22 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) -**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) -**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly) - +Build Type | Status | Artifacts +----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) ### Community Supported Builds @@ -133,12 +132,20 @@ Build Type **Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) **Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) -**Linux aarch64 CPU** Nightly
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) -**Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show) +**Linux aarch64 CPU** Nightly (Linaro)
Python 3.8 | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | [Nightly](http://snapshots.linaro.org/hpc/python/tensorflow/latest/) +**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-hpc-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-hpc-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/hpc/python/tensorflow/latest/) +**Linux aarch64 CPU** Nightly (OpenLab)
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) +**Linux aarch64 CPU** Stable Release (OpenLab) | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) **Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) +### Community Supported Containers + +Container Type | Status | Artifacts +----------------------------------------------------------------- | ------ | --------- +**TensorFlow aarch64 Neoverse-N1 CPU** Stable (Linaro)
Debian | Static | Release [2.3](https://hub.docker.com/r/linaro/tensorflow-arm-neoverse-n1) + ## Resources * [TensorFlow.org](https://www.tensorflow.org) @@ -151,6 +158,7 @@ Build Type * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) * [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) +* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow) * [TensorFlow Chat Room on StackOverflow (not actively monitored by the TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow) * [TensorFlow Blog](https://blog.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index 18649653304..962cc87ae28 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,58 @@ +# Release 2.5.0 + + + +## Breaking Changes + +* +* + +## Known Caveats + +* +* +* + +## Major Features and Improvements + +* +* + +* TPU embedding support + * Added `profile_data_directory` to `EmbeddingConfigSpec` in + `_tpu_estimator_embedding.py`. This allows embedding lookup statistics + gathered at runtime to be used in embedding layer partitioning decisions. + +## Bug Fixes and Other Changes + +* +* +* +* `tf.keras`: + * Improvements to Keras preprocessing layers: + * Discretization combiner implemented, with additional arg `epsilon`. + +* `tf.data`: + * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used + to control how external state should be handled during dataset + serialization or iterator checkpointing. + +* `tf.lite`: + * NNAPI + * Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API. + * Use `NnApiDelegate()` and related delegate configuration methods + directly. +* TF Core: + * Corrected higher-order gradients of control flow constructs (`tf.cond`, + `tf.while_loop`, and compositions like `tf.foldl`) computed with + `tf.GradientTape` inside a `tf.function`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + # Release 2.4.0 @@ -6,6 +61,15 @@ * * +* Certain float32 ops run in lower precsion on Ampere based GPUs, including + matmuls and convolutions, due to the use of + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). + Specifically, inputs to such ops are rounded from 23 bits of precision to 10 + bits of precision. This is unlikely to cause issues in practice for deep + learning models. In some cases, TensorFloat-32 is also used for complex64 ops. + TensorFloat-32 can be disabled by running + `config.experimental.enable_tensor_float_32_execution(False)`. The "Major + Features and Improvements" section has more details. * The byte layout for string tensors across the C-API has been updated to match TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s. * C-API functions `TF_StringDecode`, `TF_StringEncode`, and @@ -34,6 +98,7 @@ shape assumptions (note that you can pass shapes with `None` entries for axes that are meant to be dynamic). You can also disable the input checking entirely by setting `model.input_spec = None`. +* TF pip packages now use CUDA11 and cuDNN 8.0.2. * XLA:CPU and XLA:GPU devices are no longer registered by default. Use `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be removed). @@ -46,6 +111,49 @@ * `tf.data.experimental.service.WorkerServer` now takes a config tuple instead of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`. +* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which + updates the gradient definition for quantization which is outside the range + to be 0. To simulate the V1 the behavior of + tf.quantization.quantize_and_dequantize(...) use + tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...). +* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please + use `tf.data.Dataset.from_tensor_slices` instead. +* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, + `tf.distribute.StrategyExtended.batch_reduce_to`, + `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`. + `tf.distribute.experimental.CollectiveHints` is renamed + `tf.distribute.experimental.CommunicationOptions`. + `tf.distribute.experimental.CollectiveCommunication` is renamed + `tf.distribute.experimental.CommunicationImplementation`. +* `tf.keras.mixed_precision.experimental`: + * `AutoCastVariable.dtype` now refers to the actual variable dtype, not the + dtype it will be casted to. + * When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a + float16 or bfloat16 tensor instead of a float32 tensor. + * The property + `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` is now + a tensor, not a `LossScale` object. This means to get a loss scale of a + `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead + of `opt.loss_scale()`. + * The property `should_cast_variables` has been removed from + `tf.keras.mixed_precision.experimental.Policy` + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to + `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the + `DynamicLossScale`'s multiplier must be 2. + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to + `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of + the `DynanmicLossScale` are copied into the `LossScaleOptimizer` instead of + being reused. This means modifying the weights of the `DynamicLossScale` + will no longer affect the weights of the LossScaleOptimizer, and vice versa. + * The global policy can no longer be set to a non-floating point policy in + `tf.keras.mixed_precision.experimental.set_policy` + * In `Layer.call`, `AutoCastVariable`s will no longer be casted within + `MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a + thread local variable is used to determine whether `AutoCastVariable`s are + casted, and those two functions run with a different thread. Note this only + applies if one of these two functions is called within `Layer.call`; if one + of those two functions calls `Layer.call`, `AutoCastVariable`s will still be + casted. ## Known Caveats @@ -57,9 +165,40 @@ * * A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy. * A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models. +* Support for + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) + on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a + math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as + matrix multiplications and convolutions, to run much faster on Ampere GPUs but + with reduced precision. This reduced precision has not been found to effect + convergence quality of deep learning models in practice. TensorFloat-32 is + enabled by default, but can be disabled with + `tf.config.experimental.enable_tensor_float_32_execution`. * `tf.distribute`: + * `MultiWorkerMirroredStrategy` is graduated out of experimental. + * Peer failure will no longer cause the cluster to hang. + * Major issues with saving are fixed. + * See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial. * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental. +* The `tf.keras.mixed_precision` API has been made non-experimental. The major + changes to the new non-experimental API are: + * `tf.keras.mixed_precision.Policy` no longer takes in a + `tf.mixed_precision.experimental.LossScale` in the constructor, and no + longer has a `LossScale` associated with it. Instead, `Model.compile` will + automatically wrap the optimizer with a `LossScaleOptimizer` using dynamic + loss scaling if `Policy.name` is "mixed_float16". + * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in + different arguments. In particular, it no longer takes in a `LossScale`, and + there is no longer a `LossScale` associated with the `LossScaleOptimizer`. + Instead, `LossScaleOptimizer` directly implements fixed or dynamic loss + scaling. See the documentation of + `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details on + the differences between the experimental `LossScaleOptimizer` and the new + non-experimental `LossScaleOptimizer`. + * `tf.mixed_precision.experimental.LossScale` and its subclasses are + deprecated, as all of its functionality now exists within + `tf.keras.mixed_precision.LossScaleOptimizer` ## Bug Fixes and Other Changes @@ -109,6 +248,10 @@ ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) + * Fixes a segfault in `tf.quantization.quantize_and_dequantize` + ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265)) + * Fixes an undefined behavior float cast causing a crash + ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266)) * TF Core: * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as type annotation for variables representing a Tensor or a value @@ -130,6 +273,8 @@ stateful ops. * Added `tf.config.experimental.get_memory_usage` to return total memory usage of the device. + * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. + * Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions. * `tf.data`: * tf.data service: * Added new `tf.data.experimental.service.register_dataset` and @@ -174,7 +319,16 @@ how many times the function is called, and independent of global seed settings. * `tf.distribute`: - * + * (Experimental) Parameter server training: + * Replaced the existing + `tf.distribute.experimental.ParameterServerStrategy` symbol with + a new class that is for parameter server training in TF2. Usage with + the old symbol, usually with Estimator, should be replaced with + `tf.compat.v1.distribute.experimental.ParameterServerStrategy`. + * Added `tf.distribute.experimental.coordinator.*` namespace, + including the main API `ClusterCoordinator` for coordinating the + training cluster, the related data structure `RemoteValue` + and `PerWorkerValue`. * `tf.keras`: * Improvements from the functional API refactoring: * Functional model construction does not need to maintain a global @@ -209,21 +363,37 @@ * Improvements to Keras preprocessing layers: * TextVectorization can now accept a vocabulary list or file as an init arg. + * TextVectorization, StringLookup, and IntegerLookup can now accept a + vocabulary file via the `set_vocab_from_file` method. + * Normalization can now accept mean and variance values as init args. * In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to True, the layer returns the attention scores as an additional output argument. * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the same implementation as their `tf.losses` equivalent. + * For Keras model, the individual call of `Model.evaluate` uses no cached + data for evaluation, while `Model.fit` uses cached data when + `validation_data` arg is provided for better performance. + * Added a `save_traces` argument to `model.save`/ + `tf.keras.models.save_model` which determines whether the SavedModel + format stores the Keras model/layer call functions. The traced functions + allow Keras to revive custom models and layers without the original + class definition, but if this isn't required the tracing can be + disabled with the added option. * `tf.function` / AutoGraph: * Added `experimental_follow_type_hints` argument for `tf.function`. When True, the function may use type annotations to optimize the tracing performance. * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. - * AutoGraph now allows creating new symbols inside a TensorFLow loop, if + * AutoGraph now allows creating new symbols inside a TensorFlow loop, if the values of these symbols at an iteration does not depend on the previous iteration. These types of loops must run at least one iteration, and will raise a runtime error otherwise. + * Variables contained in `tf.Module`s that are set as attributes of + custom Keras `Layer`s and `Model`s are now tracked in + the properties `layer.trainable_variables` and + `layer.non_trainable_variables`. Example: @@ -254,8 +424,13 @@ * Deprecate `Interpreter::UseNNAPI(bool)` C++ API. * Use `NnApiDelegate()` and related delegate configuration methods directly. + * Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API + * Prefer controlling this via delegate options, e.g. + `tflite::StatefulNnApiDelegate::Options::allow_fp16' or + `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty. + * Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion. * * `tf.random`: @@ -264,7 +439,7 @@ * Math and Linear Algebra: - * + * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`. * TPU Enhancements: @@ -310,6 +485,12 @@ didn't have the keys sorted, the keys and values were not being printed in accordance with their correct mapping. +* `TensorRT` + + * We now issue a warning when the `session_config` parameter for the TF1 + converter is used or the `rewrite_config_template` field in the TF2 + converter parameter object is used. + * Other: * We have replaced uses of "whitelist" and "blacklist" with "allowlist" @@ -318,6 +499,8 @@ context. * Add `tf.config.experimental.mlir_bridge_rollout` which will help us rollout the new MLIR TPU bridge. + * Added `tf.experimental.register_filesystem_plugin` to load modular + filesystem plugins from Python * ## Thanks to our Contributors @@ -690,6 +873,7 @@ stjohnso98, , , , , * Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights. * Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights. * Mutable tables now restore checkpointed values when loaded from SavedModel. + * The user object metadata field in the SavedModel proto has been deprecated as part of the updates to Keras SavedModel. Keras was the only consumer of this field prior to the update. * GPU * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities. * Remove environmental variable `TF_USE_CUDNN`. @@ -718,6 +902,7 @@ stjohnso98, , , , , * Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses. * Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`. * Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization. + * Add `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` methods to gather and concatenate `tf.distribute.DistributedValues` across workers and devices. ### `tf.keras`: * Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs. diff --git a/WORKSPACE b/WORKSPACE index fa39cedae9b..9db1d9b80eb 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -113,26 +113,10 @@ http_archive( # Required for dependency @com_github_grpc_grpc load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") - grpc_deps() -load( - "@build_bazel_rules_apple//apple:repositories.bzl", - "apple_rules_dependencies", -) - -apple_rules_dependencies() - -load( - "@build_bazel_apple_support//lib:repositories.bzl", - "apple_support_dependencies", -) - -apple_support_dependencies() - -load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") - -bazel_version_repository(name = "bazel_version") +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") +grpc_extra_deps() load("//third_party/googleapis:repository_rules.bzl", "config_googleapis") diff --git a/configure.py b/configure.py index e381c8c20db..2f9902bc93e 100644 --- a/configure.py +++ b/configure.py @@ -1163,12 +1163,9 @@ def set_system_libs_flag(environ_cp): syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - if 'PREFIX' in environ_cp: - write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) - if 'LIBDIR' in environ_cp: - write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) - if 'INCLUDEDIR' in environ_cp: - write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) + for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): + if varname in environ_cp: + write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) def is_reduced_optimize_huge_functions_available(environ_cp): @@ -1487,7 +1484,6 @@ def main(): config_info_line('mkl', 'Build with MKL support.') config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.') config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('numa', 'Build with NUMA support.') config_info_line( 'dynamic_kernels', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 15ef7f21ed0..379b483e5d2 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -3,6 +3,7 @@ # learning applications. load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary") load( "//tensorflow/core/platform:build_config.bzl", @@ -22,10 +23,6 @@ load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused ) -load( - "//third_party/ngraph:build_defs.bzl", - "if_ngraph", -) load( "//third_party/mkl:build_defs.bzl", "if_mkl_ml", @@ -238,6 +235,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_mips64", + values = {"cpu": "mips64"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -465,14 +468,6 @@ config_setting( visibility = ["//visibility:public"], ) -# This flag is set from the configure step when the user selects with nGraph option. -# By default it should be false -config_setting( - name = "with_ngraph_support", - values = {"define": "with_ngraph_support=true"}, - visibility = ["//visibility:public"], -) - # This flag specifies whether TensorFlow 2.0 API should be built instead # of 1.* API. Note that TensorFlow 2.0 API is currently under development. config_setting( @@ -563,18 +558,45 @@ selects.config_setting_group( ], ) +# 'enable_registration_v2' opts-in to a different implementation of op and +# kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc. +# +# This setting is currently experimental. The 'v2' implementation does _not_ +# correspond to a particular, finalized design; rather, it relates to +# developing one. +# +# The current aim of the 'v2' implementation is to allow 'unused' ops and +# kernels to be discarded by the linker (to the benefit of binary size). +bool_flag( + name = "enable_registration_v2", + build_setting_default = False, + visibility = ["//visibility:public"], +) + +config_setting( + name = "registration_v1", + flag_values = {":enable_registration_v2": "False"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "registration_v2", + flag_values = {":enable_registration_v2": "True"}, + visibility = ["//visibility:public"], +) + # DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! # Instead, please use public APIs or public build rules TF provides. # If you need functionality that is not exposed, we will work with you to expand our public APIs. package_group( name = "internal", - packages = ["//tensorflow/..."], + packages = [ + "//learning/lib/ami/simple_ml/...", + "//tensorflow/...", + ], ) -package_group( - name = "ndarray_tensor_allow_list", - packages = ["//learning/pathways/..."], -) +package_group(name = "ndarray_tensor_allow_list") # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. @@ -605,7 +627,7 @@ bzl_library( "//tensorflow/core/platform/default:cuda_build_defs_bzl", "//third_party/mkl:build_defs_bzl", "//third_party/mkl_dnn:build_defs_bzl", - "//third_party/ngraph:build_defs_bzl", + "@bazel_skylib//rules:common_settings", "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", @@ -706,8 +728,11 @@ tf_cc_shared_object( visibility = ["//visibility:public"], deps = [ "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c:kernels_hdrs", + "//tensorflow/c:ops_hdrs", "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/core:core_cpu_impl", + "//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core:framework_internal_impl", "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", @@ -809,7 +834,7 @@ tf_cc_shared_object( "//tensorflow/cc:scope", "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", - ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), + ], ) # ** Targets for Windows build (start) ** diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 677ab3355ff..3f4d70ed60e 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -202,6 +202,7 @@ tf_cuda_library( ":tf_status", ":tf_tensor", "@com_google_absl//absl/strings", + "//tensorflow/c/experimental/filesystem:modular_filesystem", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", @@ -217,6 +218,8 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:logging_ops", + "//tensorflow/compiler/mlir/tfr:node_expansion_pass", + "//tensorflow/compiler/mlir/tfr:graph_decompose_pass", ], }), alwayslink = 1, @@ -509,6 +512,18 @@ tf_cuda_library( ], ) +cc_library( + name = "kernels_hdrs", + hdrs = ["kernels.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":c_api_internal", + ":tf_datatype", + ":tf_status", + ":tf_tensor", + ], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -562,6 +577,16 @@ tf_cuda_library( alwayslink = 1, ) +cc_library( + name = "ops_hdrs", + hdrs = ["ops.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":tf_datatype", + ":tf_status", + ], +) + # ----------------------------------------------------------------------------- # Tests diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a03e9227a75..9579efab94d 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" // NOLINT #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/c/experimental/filesystem/modular_filesystem.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" @@ -2606,4 +2607,14 @@ void TF_RegisterLogListener(void (*listener)(const char*)) { #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) } +void TF_RegisterFilesystemPlugin(const char* plugin_filename, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "FileSystem plugin functionality is not supported on mobile"); +#else + status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename); +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index db5f8fd68f8..f550b690e27 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1577,6 +1577,13 @@ TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); TF_CAPI_EXPORT extern void TF_RegisterLogListener( void (*listener)(const char*)); +// Register a FileSystem plugin from filename `plugin_filename`. +// +// On success, place OK in status. +// On failure, place an error status in status. +TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( + const char* plugin_filename, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 81fb9d1a2b8..0d188aa5ee0 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, collective_executor_handle->get()->StartAbort(status->status); } -TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, - const char* task, - TF_Status* status) { +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context* ctx, const char* task, int64_t timeout_in_ms, + TF_Status* status) { tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto collective_executor_handle = context->GetCollectiveExecutorHandle(); tensorflow::Notification done; collective_executor_handle->get()->remote_access()->CheckPeerHealth( - task, [&done, status](const Status& s) { + task, timeout_in_ms, [&done, status](const Status& s) { status->status = s; done.Notify(); }); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index c9c74f4e874..90e074d232f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, // Checks the health of collective ops peers. Explicit health check is needed in // multi worker collective ops to detect failures in the cluster. If a peer is // down, collective ops may hang. -TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx, - const char* task, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context* ctx, const char* task, int64_t timeout_in_ms, + TF_Status* status); // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index b90b2644269..fa0fdbae861 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -10,6 +10,9 @@ load( "tf_cuda_library", ) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") @@ -94,6 +97,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core:gpu_runtime", ] + internal_tfrt_deps(), alwayslink = 1, @@ -106,6 +110,7 @@ filegroup( "abstract_function.h", "abstract_operation.h", "abstract_tensor_handle.h", + "c_api.h", "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", @@ -638,6 +643,19 @@ cc_library( ], ) +cc_header_only_library( + name = "tfe_tensorhandle_internal_hdrs_only", + extra_deps = [ + "@com_google_absl//absl/strings", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tfe_tensorhandle_internal", + ], +) + tf_cuda_library( name = "c_api_test_util", testonly = 1, diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index d31b1e13611..07a78f97bd5 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -32,7 +32,7 @@ namespace tensorflow { // environment, a traced representation etc. class AbstractContext { protected: - enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape }; + enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler }; explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} virtual ~AbstractContext() {} diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index 4c630528f5d..997c8e0e441 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -30,7 +30,14 @@ namespace tensorflow { // tracing or immediate execution mode. class AbstractOperation { protected: - enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape }; + enum AbstractOperationKind { + kGraph, + kMlir, + kEager, + kTfrt, + kTape, + kOpHandler + }; explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} virtual ~AbstractOperation() {} diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5f388bfe0cd..9c73d1aba8c 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -70,6 +70,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #endif // !IS_MOBILE_PLATFORM #include "tensorflow/core/framework/node_def_util.h" @@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - // TODO(yuefengz): support partially specified `worker_name`. - tensorflow::core::RefCountPtr eager_client; - status->status = context->GetClient(worker_name, &eager_client); - if (!status->status.ok()) { + tensorflow::GrpcServer* grpc_server = + dynamic_cast(context->GetServer()); + if (grpc_server == nullptr) { + status->status = + tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer."); + return false; + } + tensorflow::WorkerInterface* wi = + grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name); + if (wi == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Unable to find worker interface corresponding to task ", worker_name); return false; } - // Send a rpc request to the worker to check aliveness. - tensorflow::eager::KeepAliveRequest request; - request.set_context_id(context->GetContextId()); - tensorflow::eager::KeepAliveResponse response; - - tensorflow::Status keep_alive_status; + tensorflow::GetStatusRequest request; + tensorflow::GetStatusResponse response; + tensorflow::Status remote_status; tensorflow::Notification done; - eager_client->KeepAliveAsync( - &request, &response, - [&keep_alive_status, &done](const tensorflow::Status& s) { - keep_alive_status = s; - done.Notify(); - }); + wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, + [&remote_status, &done](const tensorflow::Status& s) { + remote_status = s; + done.Notify(); + }); done.WaitForNotification(); + // We set OK status so the call does not raise any exceptions. Instead, caller + // users the return value to tell if the remote worker is alive. status->status = tensorflow::Status::OK(); - // If `context_id` doesn't exist on the remote worker, an InvalidArgument - // error will return. But this still indicates that the remote worker is - // alive. - if (keep_alive_status.ok() || - keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) { + if (remote_status.ok()) { return true; - } else { - LOG(INFO) << "Remote worker " << worker_name - << " is not alive: " << keep_alive_status.error_message(); - return false; } + LOG(INFO) << "Remote worker " << worker_name + << " is not alive: " << remote_status.error_message(); + return false; #endif // !IS_MOBILE_PLATFORM } @@ -1445,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->Executor().WaitForAllPendingNodes(); + auto* context = tensorflow::unwrap(ctx); + status->status = context->AsyncWait(); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*context->MetadataMu()); - status->status = MessageToBuffer(*context->RunMetadataProto(), buf); - context->ClearRunMetadata(); + auto run_metadata = context->ExportRunMetadata(); + status->status = MessageToBuffer(*run_metadata, buf); } namespace { diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index cc2270755bf..1ef536a66f6 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); } + +const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return nullptr; + } + return tensorflow::unwrap(h)->DeviceType(&status->status); +} + +int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return -1; + } + return tensorflow::unwrap(h)->DeviceId(&status->status); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 12546c6082a..d0739a5437d 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status); +// Returns the device type of the operation that produced `h`. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the device ID of the operation that produced `h`. +TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 4975d303375..4fe83b5116d 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) { TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + const char* device_type = TFE_TensorHandleDeviceType(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_type, nullptr); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int device_id = TFE_TensorHandleDeviceID(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_id, -1); + ASSERT_EQ("Invalid handle", string(TF_Message(status.get()))); +} + +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + // Disable the test if no GPU is present. + string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + device_type = TFE_TensorHandleDeviceType(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type; + + device_id = TFE_TensorHandleDeviceID(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TensorHandleDefaults) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx); + const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type; + int device_id = TFE_TensorHandleDeviceID(h_default, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id) << device_id; + + TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice( + h_default, ctx, "/device:CPU:0", status.get()); + const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu; + int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(0, device_id_cpu) << device_id_cpu; + + TFE_DeleteTensorHandle(h_default); + TFE_DeleteTensorHandle(h_cpu); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); + TFE_DeleteContext(ctx); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index fd208c6770d..0f5f494e5e2 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -769,7 +769,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); EXPECT_NE(TF_OK, TF_GetCode(status)); EXPECT_EQ(nullptr, t); - const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + const char* msg = "In[0] mismatch In[1] shape: 2 vs. 3: [2,2] [3,2]"; EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) << TF_Message(status); // Since error is not cleared, the following copy with correct device will diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index 7a438085fb5..393ad2ceb98 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) { } TEST_P(GradientCheckerTest, TestGradCheckMatMul) { + // Computing numerical gradients with TensorFloat-32 is numerically unstable + enable_tensor_float_32_execution(false); + std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index cd4febba8c1..a81d7aa6952 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -62,10 +62,12 @@ Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer)); return Status::OK(); } - // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -74,11 +76,11 @@ Status AddGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add")); // Compute x+y. @@ -97,7 +99,6 @@ Status AddGradModel(AbstractContext* ctx, } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; - delete tape; return Status::OK(); } @@ -109,10 +110,10 @@ Status ExpGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector exp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR( ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); std::unordered_map @@ -128,7 +129,6 @@ Status ExpGradModel(AbstractContext* ctx, exp_output->Unref(); } outputs[0] = out_grads[0]; - delete tape; return Status::OK(); } @@ -140,10 +140,10 @@ Status SqrtGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector sqrt_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR( ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); std::unordered_map @@ -159,7 +159,6 @@ Status SqrtGradModel(AbstractContext* ctx, sqrt_output->Unref(); } outputs[0] = out_grads[0]; - delete tape; return Status::OK(); } @@ -172,12 +171,12 @@ Status IdentityNGradModel(AbstractContext* ctx, absl::Span outputs, const GradientRegistry& registry) { TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); + auto tape = std::make_unique(/*persistent=*/false); tape->Watch(ToId(inputs[0])); tape->Watch(ToId(inputs[1])); vector identity_n_outputs(2); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); TF_RETURN_IF_ERROR(ops::IdentityN( tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); @@ -195,6 +194,105 @@ Status IdentityNGradModel(AbstractContext* ctx, } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; + return Status::OK(); +} + +// Computes +// y = - inputs[0] +// return grad(y, {inputs[0]}) +Status NegGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = std::make_unique(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); + + std::vector neg_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); + TF_RETURN_IF_ERROR( + ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg")); + + std::unordered_map + source_tensors_that_are_targets; + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto neg_output : neg_outputs) { + neg_output->Unref(); + } + outputs[0] = out_grads[0]; + return Status::OK(); +} + +// Computes +// y = inputs[0] - inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status SubGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = std::make_unique(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector sub_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); + TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs, + absl::MakeSpan(sub_outputs), + "Sub")); // Compute x-y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto sub_output : sub_outputs) { + sub_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + return Status::OK(); +} + +// Computes +// y = inputs[0] * inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status MulGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector mul_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs, + absl::MakeSpan(mul_outputs), + "Mul")); // Compute x*y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto mul_output : mul_outputs) { + mul_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; delete tape; return Status::OK(); } @@ -536,6 +634,172 @@ TEST_P(CppGradients, TestIdentityNGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestNegGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = - x + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +TEST_P(CppGradients, TestSubGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // tape.watch(y) + // y = x - y + // outputs = tape.gradient(y, [x, y]) + std::vector outputs(2); + s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; + + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, -1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); +} + +TEST_P(CppGradients, TestMulGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // tape.watch(y) + // y = x * y + // outputs = tape.gradient(y, [x, y]) + std::vector outputs(2); + s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 2.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; + + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); +} + TEST_P(CppGradients, TestSetAttrString) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -575,7 +839,7 @@ TEST_P(CppGradients, TestSetAttrString) { int num_retvals = 1; std::vector outputs(1); GradientRegistry registry; - std::unique_ptr tape(new Tape(/*persistent=*/false)); + auto tape = std::make_unique(/*persistent=*/false); s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), &num_retvals, &forward_op, tape.get(), registry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index a3e3857b34b..27fa17127b8 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext { // Returns the device placement policy for the current thread. virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + + // Return the collected RunMetadata. This method will transfer the ownership + // to the caller. + virtual std::unique_ptr ExportRunMetadata() = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; @@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext { // Update the Eager Executor for current thread. virtual void SetExecutorForThread(EagerExecutor* executor) = 0; - // Configure graph collection in RunMetadata. - virtual void SetShouldStoreGraphs(bool value) = 0; - protected: explicit ImmediateExecutionContext(AbstractContextKind kind) : AbstractContext(kind) {} diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index 6d32d482747..bb6d471f12f 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { virtual const char* DeviceName(Status* status) const = 0; // Returns the device where the tensor was placed. virtual const char* BackingDeviceName(Status* status) const = 0; + // Returns the device type which created the handle. + virtual const char* DeviceType(Status* status) const = 0; + // Returns the device ID which created the handle. + virtual int DeviceId(Status* status) const = 0; // Returns a tensor for the handle. If tensor is remote, it will be copied. virtual AbstractTensorInterface* Resolve(Status* status) = 0; diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc index 4114f50a798..16cb01110fd 100644 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -43,6 +44,11 @@ class CppGradients TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.error_message(); + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); } }; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index e270bfcbb80..095f33ff303 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -58,7 +58,7 @@ using ExecutorPtr = std::unique_ptr; class DeviceThread { public: // Starts a background thread waiting for `StartExecute`. - explicit DeviceThread(const std::string& device) + explicit DeviceThread(const std::string& device, const bool is_async) : status_(TF_NewStatus()), device_(device), // If the context's default exector is set to async, re-using that in @@ -67,7 +67,7 @@ class DeviceThread { // // TODO(allenl): We should have an async API that works with the // parallel device. - executor_(TFE_NewExecutor(/*is_async=*/false)), + executor_(TFE_NewExecutor(is_async)), op_(nullptr), thread_(tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "parallel_device_execute", @@ -236,12 +236,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, } } -ParallelDevice::ParallelDevice(const std::vector& devices) +ParallelDevice::ParallelDevice(const std::vector& devices, + const bool is_async) : underlying_devices_(devices) { device_threads_.reserve(devices.size()); for (int device_index = 0; device_index < devices.size(); ++device_index) { device_threads_.emplace_back( - new DeviceThread(devices[device_index].c_str())); + new DeviceThread(devices[device_index].c_str(), is_async)); } } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index b3dc47ab088..1bb9ce0f663 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -49,7 +49,10 @@ class DeviceThread; // placed on each underlying device. class ParallelDevice { public: - explicit ParallelDevice(const std::vector& devices); + // Eager async execution is only supported when remote eager is not in use + // (b/157523095). + explicit ParallelDevice(const std::vector& devices, + const bool is_async = false); ~ParallelDevice(); diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index 5ff28e4229a..50a9f54cb1e 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); std::string cacheKey(scheme); - hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); if (scheme == "file") { - libhdfs->hdfsBuilderSetNameNode(builder, nullptr); + namenode = ""; } else if (scheme == "viewfs") { char* defaultFS = nullptr; libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); @@ -200,24 +199,27 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, // The default NameNode configuration will be used (from the XML // configuration files). See: // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 - libhdfs->hdfsBuilderSetNameNode(builder, "default"); + namenode = "default"; } else if (scheme == "har") { std::string path_har = path; SplitArchiveNameAndPath(&path_har, &namenode, status); if (TF_GetCode(status) != TF_OK) return nullptr; - libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); - cacheKey += namenode; } else { - libhdfs->hdfsBuilderSetNameNode( - builder, namenode.empty() ? "default" : namenode.c_str()); - cacheKey += namenode; + if (namenode.empty()) { + namenode = "default"; + } } + cacheKey += namenode; + absl::MutexLock l(&hadoop_file->connection_cache_lock); if (hadoop_file->connection_cache.find(cacheKey) == hadoop_file->connection_cache.end()) { + hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); + libhdfs->hdfsBuilderSetNameNode( + builder, namenode.empty() ? nullptr : namenode.c_str()); auto cacheFs = libhdfs->hdfsBuilderConnect(builder); if (cacheFs == nullptr) { - TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); + TF_SetStatusFromIOError(status, TF_ABORTED, strerror(errno)); return cacheFs; } hadoop_file->connection_cache[cacheKey] = cacheFs; diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 5cba7b28fda..3ee5294ca15 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -24,6 +24,7 @@ using std::vector; using tensorflow::ops::Conj; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; +using tensorflow::ops::Neg; using tensorflow::ops::SqrtGrad; namespace tensorflow { @@ -201,6 +202,93 @@ class MatMulGradientFunction : public GradientFunction { AttrBuilder forward_attrs; }; +class NegGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Neg op Y = -X, the gradients are: + * + * dX = -U + * + */ + + grad_outputs->resize(1); + std::string name = "Neg_Grad"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~NegGradientFunction() override {} +}; + +class SubGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a Sub op A-B, the gradients are: + * + * dA = U + * dB = -U + * + */ + + grad_outputs->resize(2); + + // Grad for A + DCHECK(grad_inputs[0]); + (*grad_outputs)[0] = grad_inputs[0]; + (*grad_outputs)[0]->Ref(); + + // Grad for B + // negate the upstream grad + std::vector neg_outputs(1); + std::string name = "Neg_Sub_Grad_B"; + TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(neg_outputs), name.c_str())); + (*grad_outputs)[1] = neg_outputs[0]; + + return Status::OK(); + } + ~SubGradientFunction() override {} +}; + +class MulGradientFunction : public GradientFunction { + public: + explicit MulGradientFunction(vector f_inputs) + : forward_inputs(f_inputs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a mul op A*B, the gradients are: + * + * dA = U * B + * dB = A * U + * + */ + + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + grad_outputs->resize(2); + std::vector mul_outputs(1); + + // Gradient for A + std::string name = "Mul_Grad_A"; + TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]}, + absl::MakeSpan(mul_outputs), name.c_str())); + (*grad_outputs)[0] = mul_outputs[0]; + + // Gradient for B + name = "Mul_Grad_B"; + TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad}, + absl::MakeSpan(mul_outputs), name.c_str())); + (*grad_outputs)[1] = mul_outputs[0]; + return Status::OK(); + } + ~MulGradientFunction() override {} + + private: + vector forward_inputs; +}; + } // namespace BackwardFunction* AddRegisterer(const ForwardOperation& op) { @@ -239,5 +327,32 @@ BackwardFunction* SqrtRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* NegRegisterer(const ForwardOperation& op) { + auto gradient_function = new NegGradientFunction; + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* SubRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new SubGradientFunction; + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* MulRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new MulGradientFunction(op.inputs); + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 7faeadcca81..d2a0bf2b646 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -24,6 +24,9 @@ BackwardFunction* AddRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op); BackwardFunction* SqrtRegisterer(const ForwardOperation& op); +BackwardFunction* NegRegisterer(const ForwardOperation& op); +BackwardFunction* SubRegisterer(const ForwardOperation& op); +BackwardFunction* MulRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc index 0b247d08f6c..841782aa6da 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape, parent_op_(parent_op), tape_(tape), registry_(registry) { - // TODO(srbs): Make AbstractOperation RefCounted. + // TODO(b/172003047): Consider making AbstractOperation RefCounted. // parent_op_->Ref(); } void TapeOperation::Release() { @@ -33,7 +33,7 @@ void TapeOperation::Release() { delete this; } TapeOperation::~TapeOperation() { - // TODO(srbs): Make AbstractOperation RefCounted. + // TODO(b/172003047): Consider making AbstractOperation RefCounted. // parent_op->Unref(); } Status TapeOperation::Reset(const char* op, const char* raw_device_name) { diff --git a/tensorflow/c/experimental/op_handler/BUILD b/tensorflow/c/experimental/op_handler/BUILD new file mode 100644 index 00000000000..bdb5328180c --- /dev/null +++ b/tensorflow/c/experimental/op_handler/BUILD @@ -0,0 +1,43 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_test( + name = "internal_test", + srcs = ["internal_test.cc"], + deps = [ + ":internal", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "internal", + srcs = ["internal.cc"], + hdrs = ["internal.h"], + deps = [ + ":wrapper_operation", + "//tensorflow/c:conversion_macros", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "wrapper_operation", + srcs = ["wrapper_operation.cc"], + hdrs = ["wrapper_operation.h"], + deps = ["//tensorflow/c/eager:abstract_operation"], +) diff --git a/tensorflow/c/experimental/op_handler/internal.cc b/tensorflow/c/experimental/op_handler/internal.cc new file mode 100644 index 00000000000..b9acbf44583 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal.cc @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_ + +#include "tensorflow/c/experimental/op_handler/internal.h" + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx) + : AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {} +OpHandlerContext::~OpHandlerContext() {} +void OpHandlerContext::Release() { delete this; } +Status OpHandlerContext::RegisterFunction(AbstractFunction* function) { + return parent_ctx_->RegisterFunction(function); +} + +Status OpHandlerContext::RemoveFunction(const string& function) { + return parent_ctx_->RemoveFunction(function); +} + +void OpHandlerContext::set_default_handler(OpHandler* handler) { + handler->Ref(); + default_handler_.reset(handler); +} + +OpHandlerOperation* OpHandlerContext::CreateOperation() { + OpHandlerOperation* result = + new OpHandlerOperation(parent_ctx_->CreateOperation()); + if (default_handler_ != nullptr) { + result->set_handler(default_handler_.get()); + } + return result; +} + +OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op) + : WrapperOperation(parent_op, kOpHandler) {} + +OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); } + +void OpHandlerOperation::set_handler(OpHandler* handler) { + if (handler != nullptr) { + handler->Ref(); + } + handler_.reset(handler); +} + +Status OpHandlerOperation::Execute(absl::Span retvals, + int* num_retvals) { + if (handler_ == nullptr) { + return WrapperOperation::Execute(retvals, num_retvals); + } else { + return handler_->Execute(this, retvals, num_retvals); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ diff --git a/tensorflow/c/experimental/op_handler/internal.h b/tensorflow/c/experimental/op_handler/internal.h new file mode 100644 index 00000000000..de893f77a7e --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal.h @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class OpHandlerOperation; + +// Op handlers are a convenient way to intercept and transform computation. +// +// The implementation is currently experimental and incomplete, but aims +// eventually to support tracing and replay of function bodies, gradients +// through copy operations, and a variety of hooks for things like debug +// strings. A public C API for op handlers is planned. +class OpHandler : public core::RefCounted { + public: + // Called on operation->Execute when operation->get_handler() == this. + // + // Allows the handler to customize or inspect `operation`'s execution. + virtual Status Execute(OpHandlerOperation* operation, + absl::Span retvals, + int* num_retvals) = 0; + // Creates a new handler by merging this handler with `next_handler`. + // + // The new handler is expected to transform operations first with this handler + // and then execute the resulting operations on `next_handler` (by calling + // `OpHandlerOperation::set_handler` and passing `next_handler`). If this is + // not possible then the merge operation should fail. + virtual Status Merge(OpHandler* next_handler, + core::RefCountPtr& merged_handler) = 0; +}; + +// Keeps some handler-specific metadata, but otherwise wraps a single +// AbstractOperation in the underlying context. The operation is created, its +// attributes set, etc., and at execution time it is presented to its handler, +// which may choose to execute it or simply inspect it and do something else. +// +// This is somewhat different than the Context approach, where the operation's +// construction is streamed through each layered Context. The streaming approach +// would require a much larger op handler public API, one function pointer per +// attribute type, and there is some ambiguity before an op is finalized about +// whether it should be presented as-is to handlers (regular operations) or +// replayed (function calls and control flow operations). +class OpHandlerOperation : public WrapperOperation { + public: + explicit OpHandlerOperation(AbstractOperation*); + OpHandler* get_handler(); + void set_handler(OpHandler* handler); + Status Execute(absl::Span retvals, + int* num_retvals) override; + + protected: + core::RefCountPtr handler_; +}; + +// A context which allows a default handler to be set for new operations. It +// otherwise defers to the context it wraps. +// +// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar +// in some ways. Having each handler be its own context seems almost doable, +// with things like copy operations and function/control flow replay being +// somewhat tricky (since they should be generated at the top of the handler +// stack and "caught" at the bottom). After handlers have evolved for a bit we +// should re-evaluate whether the handler+context concepts can be merged. +class OpHandlerContext : public AbstractContext { + public: + explicit OpHandlerContext(AbstractContext*); + void Release() override; + OpHandlerOperation* CreateOperation() override; + Status RegisterFunction(AbstractFunction*) override; + Status RemoveFunction(const string&) override; + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kOpHandler; + } + ~OpHandlerContext() override; + + void set_default_handler(OpHandler* handler); + + private: + AbstractContext* parent_ctx_; // Not owned. + core::RefCountPtr default_handler_; +}; + +class ReleaseOpHandlerOperation { + public: + void operator()(OpHandlerOperation* operation) { operation->Release(); } +}; + +typedef std::unique_ptr + OpHandlerOperationPtr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ diff --git a/tensorflow/c/experimental/op_handler/internal_test.cc b/tensorflow/c/experimental/op_handler/internal_test.cc new file mode 100644 index 00000000000..d8ac8b3b985 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/op_handler/internal.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestOpHandler : public OpHandler { + public: + TestOpHandler() : last_operation_(new std::string("")) {} + Status Execute(OpHandlerOperation* operation, + absl::Span retvals, + int* num_retvals) override { + CHECK(operation->get_handler() == this); + *last_operation_ = operation->Name(); + operation->set_handler(next_handler_.get()); + return operation->Execute(retvals, num_retvals); + } + Status Merge(OpHandler* next_handler, + core::RefCountPtr& merged_handler) override { + merged_handler.reset(new TestOpHandler(next_handler, last_operation_)); + return Status::OK(); + } + + core::RefCountPtr next_handler_ = nullptr; + // Shared between merged handlers of this type. + std::shared_ptr last_operation_; + + private: + TestOpHandler(OpHandler* next_handler, + std::shared_ptr last_operation) + : next_handler_(next_handler), last_operation_(last_operation) { + next_handler->Ref(); + } +}; + +TEST(INTERNAL_TEST, UseOpHandler) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr + c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()), + TF_DeleteExecutionContext); + OpHandlerContext ctx(unwrap(c_ctx.get())); + core::RefCountPtr outer_handler(new TestOpHandler()); + core::RefCountPtr inner_handler(new TestOpHandler()); + ctx.set_default_handler(outer_handler.get()); + OpHandlerOperationPtr op(ctx.CreateOperation()); + Status s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + std::vector retvals; + int num_retvals = 0; + EXPECT_EQ("", *outer_handler->last_operation_); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + EXPECT_EQ("NoOp", *outer_handler->last_operation_); + *outer_handler->last_operation_ = ""; + EXPECT_EQ("", *inner_handler->last_operation_); + + // This op executes on both handlers, changing the state of `inner_handler` + // since the handler has decided to preserve that state across merges. + core::RefCountPtr merged; + s = inner_handler->Merge(outer_handler.get(), merged); + ctx.set_default_handler(merged.get()); + op.reset(ctx.CreateOperation()); + s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + EXPECT_EQ("NoOp", *inner_handler->last_operation_); + EXPECT_EQ("NoOp", *outer_handler->last_operation_); + + inner_handler.reset(); + outer_handler.reset(); + op.reset(ctx.CreateOperation()); + s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/op_handler/wrapper_operation.cc b/tensorflow/c/experimental/op_handler/wrapper_operation.cc new file mode 100644 index 00000000000..018bba04b8a --- /dev/null +++ b/tensorflow/c/experimental/op_handler/wrapper_operation.cc @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" + +namespace tensorflow { +WrapperOperation::WrapperOperation(AbstractOperation* parent_op, + AbstractOperationKind kind) + : AbstractOperation(kind), parent_op_(parent_op) { + // TODO(b/172003047): Consider making AbstractOperation RefCounted. + // parent_op_->Ref(); +} +void WrapperOperation::Release() { + parent_op_->Release(); + // TODO(b/172003047): Consider making AbstractOperation RefCounted. + delete this; +} + +Status WrapperOperation::Reset(const char* op, const char* raw_device_name) { + return parent_op_->Reset(op, raw_device_name); +} +const string& WrapperOperation::Name() const { return parent_op_->Name(); } +const string& WrapperOperation::DeviceName() const { + return parent_op_->DeviceName(); +} +Status WrapperOperation::SetDeviceName(const char* name) { + return parent_op_->SetDeviceName(name); +} +Status WrapperOperation::AddInput(AbstractTensorHandle* input) { + return parent_op_->AddInput(input); +} +Status WrapperOperation::AddInputList( + absl::Span inputs) { + return parent_op_->AddInputList(inputs); +} +Status WrapperOperation::SetAttrString(const char* attr_name, const char* data, + size_t length) { + return parent_op_->SetAttrString(attr_name, data, length); +} +Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) { + return parent_op_->SetAttrInt(attr_name, value); +} +Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) { + return parent_op_->SetAttrFloat(attr_name, value); +} +Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) { + return parent_op_->SetAttrBool(attr_name, value); +} +Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) { + return parent_op_->SetAttrType(attr_name, value); +} +Status WrapperOperation::SetAttrShape(const char* attr_name, + const int64_t* dims, const int num_dims) { + return parent_op_->SetAttrShape(attr_name, dims, num_dims); +} +Status WrapperOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { + return parent_op_->SetAttrFunction(attr_name, value); +} +Status WrapperOperation::SetAttrFunctionName(const char* attr_name, + const char* value, size_t length) { + return parent_op_->SetAttrFunctionName(attr_name, value, length); +} +Status WrapperOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { + return parent_op_->SetAttrTensor(attr_name, tensor); +} +Status WrapperOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) { + return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values); +} +Status WrapperOperation::SetAttrFloatList(const char* attr_name, + const float* values, int num_values) { + return parent_op_->SetAttrFloatList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, int num_values) { + return parent_op_->SetAttrIntList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) { + return parent_op_->SetAttrTypeList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { + return parent_op_->SetAttrBoolList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, int num_values) { + return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); +} +Status WrapperOperation::SetAttrFunctionList( + const char* attr_name, absl::Span values) { + return parent_op_->SetAttrFunctionList(attr_name, values); +} +AbstractOperation* WrapperOperation::GetBackingOperation() { + return parent_op_; +} +Status WrapperOperation::Execute(absl::Span retvals, + int* num_retvals) { + return parent_op_->Execute(retvals, num_retvals); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/op_handler/wrapper_operation.h b/tensorflow/c/experimental/op_handler/wrapper_operation.h new file mode 100644 index 00000000000..b0ec9f174f0 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/wrapper_operation.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ + +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { + +// Forwards all of the AbstractOperation's methods to its wrapped operation. +// +// Useful as a base class to default to forwarding while adding some +// customization. +class WrapperOperation : public AbstractOperation { + public: + explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind); + void Release() override; + Status Reset(const char* op, const char* raw_device_name) override; + const string& Name() const override; + const string& DeviceName() const override; + Status SetDeviceName(const char* name) override; + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span inputs) override; + Status Execute(absl::Span retvals, + int* num_retvals) override; + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, DataType value) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList( + const char* attr_name, + absl::Span values) override; + AbstractOperation* GetBackingOperation(); + + private: + AbstractOperation* parent_op_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 95bb12e8e50..462993e8918 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -11,17 +11,29 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "stream_executor_hdrs", + hdrs = ["stream_executor.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_headers", + ], +) + cc_library( name = "stream_executor", srcs = ["stream_executor.cc"], hdrs = ["stream_executor.h"], - visibility = ["//visibility:public"], + visibility = ["//tensorflow:internal"], deps = [ ":stream_executor_internal", "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/core:lib", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:strcat", "//tensorflow/stream_executor:executor_cache", "//tensorflow/stream_executor:multi_platform_manager", "//tensorflow/stream_executor:platform", diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 09442a4f7b7..ec2bada791e 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -28,7 +28,10 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -40,6 +43,8 @@ limitations under the License. using tensorflow::StatusFromTF_Status; namespace stream_executor { +using tensorflow::StringPiece; + namespace { #define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ @@ -59,10 +64,35 @@ namespace { } \ } while (0) +port::Status ValidateDeviceType(StringPiece type) { + // Validate device type. Device type must start with a capital letter and + // consist of capital letters and underscores. Reasoning behind this decision: + // * At the minimum we want to disallow '/' and ':' since + // these characters are used in device spec, for e.g. + // /job:foo/replica:12/device:GPU:1. + // * Underscores seem useful, for e.g. XLA_GPU uses underscores. + // * Allowing lowercase might get confusing. For example, say someone + // registers a new type called "Gpu". It might be confusing for users that + // "Gpu" is not the same device type as "GPU". + // Note that lowercase "cpu" and "gpu" are currently supported only for + // legacy reasons: + // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd + static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"}; + bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx); + if (!matches) { + return port::FailedPreconditionError( + tensorflow::strings::StrCat("Device name/type '", type, "' must match ", + kTfDeviceTypeRegEx->pattern(), ".")); + } + return port::Status::OK(); +} + port::Status ValidateSPPlatform(const SP_Platform& platform) { VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); VALIDATE_MEMBER(SP_Platform, platform, name); VALIDATE_MEMBER(SP_Platform, platform, type); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name)); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type)); // `visible_device_count` could be 0 at initialization time. return port::Status::OK(); } diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h index ba6b1c564a8..bec77ef520b 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -52,7 +52,7 @@ limitations under the License. // params.device = &device; // // /* Plugin code below */ -// constexpr char DEVICE_NAME[] = "MyDevice"; +// constexpr char DEVICE_NAME[] = "MY_DEVICE"; // constexpr char DEVICE_TYPE[] = "GPU"; // // void create_device(const SP_Platform* platform, @@ -416,10 +416,15 @@ typedef struct SP_Platform { void* ext; // free-form data set by plugin - // Platform name. Must be null-terminated. + // Platform name (also referred to as subtype), for example MY_DEVICE. + // The name must start with a capital letter and consist of + // capital letters and underscores. + // Must be null-terminated. const char* name; // Device type name, for example GPU. Must be null-terminated. + // The name must start with a capital letter and consist of + // capital letters and underscores. const char* type; // Number of visible devices diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index b28d1f6fc6d..56c4ea09052 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -41,9 +41,9 @@ struct SP_Timer_st { namespace stream_executor { namespace { -constexpr int DEVICE_COUNT = 2; -constexpr char DEVICE_NAME[] = "MyDevice"; -constexpr char DEVICE_TYPE[] = "GPU"; +constexpr int kDeviceCount = 2; +constexpr char kDeviceName[] = "MY_DEVICE"; +constexpr char kDeviceType[] = "GPU"; /*** Create SP_StreamExecutor (with empty functions) ***/ void allocate(const SP_Device* const device, uint64_t size, @@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) { void PopulateDefaultPlatform(SP_Platform* platform, SP_PlatformFns* platform_fns) { *platform = {SP_PLATFORM_STRUCT_SIZE}; - platform->name = DEVICE_NAME; - platform->type = DEVICE_TYPE; - platform->visible_device_count = DEVICE_COUNT; + platform->name = kDeviceName; + platform->type = kDeviceType; + platform->visible_device_count = kDeviceCount; platform_fns->create_device = create_device; platform_fns->destroy_device = destroy_device; platform_fns->create_device_fns = create_device_fns; @@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) { port::Status status = InitStreamExecutorPlugin(plugin_init); TF_ASSERT_OK(status); port::StatusOr maybe_platform = - MultiPlatformManager::PlatformWithName("MyDevice"); + MultiPlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = maybe_platform.ConsumeValueOrDie(); - ASSERT_EQ(platform->Name(), DEVICE_NAME); - ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); + ASSERT_EQ(platform->Name(), kDeviceName); + ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount); port::StatusOr maybe_executor = platform->ExecutorForDevice(0); @@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) { ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); } +TEST(StreamExecutor, InvalidNameWithSemicolon) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID:NAME"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT( + status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID:NAME' must match")); +} + +TEST(StreamExecutor, InvalidNameWithSlash) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID/"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT(status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID/' must match")); +} + TEST(StreamExecutor, CreateDeviceNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc index c194dcd686b..c6468e0ab80 100644 --- a/tensorflow/c/kernels/bitcast_op.cc +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -148,7 +148,7 @@ void RegisterBitcastOpKernel() { << "Error while registering bitcast kernel"; } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM { auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU, &BitcastOp_Create, &BitcastOp_Compute, diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index c9df2cc34d1..5ddc9a46be1 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -352,7 +352,7 @@ class DeviceKernelOpTest : public OpsTestBase { EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_DeleteStatus(status); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::unique_ptr device( DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0")); OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); @@ -361,7 +361,7 @@ class DeviceKernelOpTest : public OpsTestBase { TF_ASSERT_OK(InitOp()); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM const char* device_name_ = tensorflow::DEVICE_GPU; #else const char* device_name_ = tensorflow::DEVICE_CPU; @@ -468,7 +468,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) { int64_t dim = 1; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -505,7 +505,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) { int64_t dim = 0; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -538,7 +538,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) { int64_t dim[2] = {2, 3}; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -646,7 +646,7 @@ template void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes, TF_OpKernelContext* ctx) { T* data = reinterpret_cast(TF_TensorData(tensor)); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM OpKernelContext* cc_ctx = reinterpret_cast(ctx); cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values, tensor_size_bytes); diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index bf5ada89cdd..8f7e447d322 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -251,7 +251,6 @@ cc_library_with_android_deps( deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:protos_all_cc", ], ) @@ -266,7 +265,6 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index e9173227aad..480243a29e6 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,13 +15,12 @@ limitations under the License. #include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" - namespace tensorflow { namespace ops { namespace { @@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); -Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - grad_outputs->push_back(Identity(scope, grad_inputs[0])); - grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(NoGradient()); +Status QuantizeAndDequantizeV4GradHelper(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input input = Shape(scope, op.input(0)); + Input input_min = op.input(1); + Input input_max = op.input(2); + int64 axis; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); + auto qdq_v4_grad = QuantizeAndDequantizeV4Grad( + scope, grad_inputs[0], input, input_min, input_max, + QuantizeAndDequantizeV4Grad::Axis(axis)); + grad_outputs->push_back(qdq_v4_grad.input_backprop); + grad_outputs->push_back(qdq_v4_grad.input_min_backprop); + grad_outputs->push_back(qdq_v4_grad.input_max_backprop); return scope.status(); } -REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4", + QuantizeAndDequantizeV4GradHelper); Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 70d080a682f..dcd652d9fdf 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -404,10 +404,12 @@ Status RestoreSession(const RunOptions& run_options, const uint64 read_start_microseconds = Env::Default()->NowMicros(); std::vector asset_file_defs; TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); - TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, - meta_graph.saver_def().restore_op_name(), - meta_graph.saver_def().filename_tensor_name(), - asset_file_defs, session->get())); + if (meta_graph.has_saver_def()) { + TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, + meta_graph.saver_def().restore_op_name(), + meta_graph.saver_def().filename_tensor_name(), + asset_file_defs, session->get())); + } // Record walltime spent in restoring graph from disk, but postpone metric // increments until graph init finishes. const uint64 restore_graph_walltime = diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d87b5f0715..5c84eecd976 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -7,6 +7,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_ load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") @@ -283,6 +286,7 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], @@ -291,7 +295,7 @@ cc_library( # Header-only version of "flags" library, for linking from the shared object # without ODR violations. cc_library( - name = "flags_headers_only", + name = "flags_headers", hdrs = ["flags.h"], visibility = [":friends"], deps = [ @@ -302,6 +306,11 @@ cc_library( ], ) +cc_header_only_library( + name = "flags_headers_only", + deps = [":flags_headers"], +) + cc_library( name = "common", srcs = [ @@ -447,8 +456,8 @@ cc_library( # Header-only version of "flags" library, for linking from the shared object # without ODR violations. cc_library( - name = "get_compiler_ir_hdrs_only", - hdrs = ["get_compiler_ir.h"], + name = "get_compiler_ir_hdrs", + textual_hdrs = ["get_compiler_ir.h"], visibility = [ ":internal", "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", @@ -463,6 +472,23 @@ cc_library( ], ) +cc_header_only_library( + name = "get_compiler_ir_hdrs_only", + deps = [":get_compiler_ir_hdrs"], +) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) + cc_library( name = "xla_kernel_creator", srcs = [ @@ -520,8 +546,8 @@ cc_library( hdrs = ["resource_operation_safety_analysis.h"], deps = [ ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -692,7 +718,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/cc:scope_internal", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:side_effect_util", @@ -705,6 +730,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -732,9 +758,9 @@ cc_library( deps = [ ":flags", ":xla_activity_proto_cc", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -842,9 +868,12 @@ tf_cc_test( "partially_decluster_pass_test.cc", "rearrange_function_argument_pass_test.cc", ], - # TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value - # error. - tags = ["nomsan"] + tf_cuda_tests_tags(), + tags = [ + # TODO(b/141643254) Re-enable msan after fixing + # use-of-uninitialized-value error. + "nomsan", + "no_cuda_asan", # TODO(b/171317460): re-enable. + ] + tf_cuda_tests_tags(), deps = [ ":common", ":compilability_check_util", @@ -965,13 +994,13 @@ cc_library( ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -1075,15 +1104,3 @@ cc_library( ], alwayslink = 1, ) - -# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. -cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 62e121420c3..87b06c2ab36 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" @@ -42,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 65da072483b..224bedabd3b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -24,11 +24,11 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index d482642b44c..fd55cab637c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index ee7daf092da..52d8fb94ff6 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -167,8 +167,16 @@ void AllocateAndParseFlags() { jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags->jitter_amount = 1e-5; - mlir_flags = new MlirCommonFlags; - mlir_flags->tf_mlir_enable_mlir_bridge = false; + // The `enable_mlir_bridge` flag allows the user to explicitly request that + // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge. + // + // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the + // user has made an explicit request. That is, if this variable is set to + // true, the program honors the user's request as per `enable_mlir_bridge`; if + // it's set to false, the default behavior is used (which may run either + // bridge, on a per-graph basis). + bool enable_mlir_bridge = false; + bool enable_mlir_bridge_is_explicit = false; auto setter_for_jitter_tensor_names = [](string sequence) { jitter_flags->tensor_names = absl::StrSplit(sequence, ','); @@ -217,12 +225,24 @@ void AllocateAndParseFlags() { "The amount of jitter to introduce. This amount is added to each " "element in the tensors named in `tensor_names."), - Flag("tf_mlir_enable_mlir_bridge", - &mlir_flags->tf_mlir_enable_mlir_bridge, - "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); + Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, + "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", + &enable_mlir_bridge_is_explicit)}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); + + mlir_flags = new MlirCommonFlags; + if (!enable_mlir_bridge_is_explicit) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } else if (enable_mlir_bridge) { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } else { + mlir_flags->tf_mlir_enable_mlir_bridge = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; + } } } // namespace diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 5612b3b5864..a0860da7b04 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags { // Flags for common MLIR configurations. struct MlirCommonFlags { - bool tf_mlir_enable_mlir_bridge; + ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; }; // Return a pointer to the DumpGraphFlags struct; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 12b40b1c83b..0f0f43cbad6 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -274,18 +274,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. - LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -522,18 +510,6 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::ThenExecuteFunction then_execute; - if (ctx->op_device_context()) { - then_execute = [&](se::Stream* stream, std::function fn) { - Status status = ctx->op_device_context()->ThenExecute( - down_cast(ctx->device()), stream, std::move(fn)); - if (!status.ok()) { - // This should never happen. - LOG(ERROR) << "ThenExecute failed " << status; - } - }; - run_options.set_then_execute_function(&then_execute); - } Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 317e29d4a84..abd5d8d02f6 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,12 +30,12 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" @@ -1801,11 +1801,11 @@ absl::flat_hash_map>* GetAllowlistTable() { "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze", "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/, "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2", - "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", - "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", - "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", - "TensorStridedSliceUpdate", + "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad", + "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", + "SplitV", "StridedSlice", "StridedSliceGrad", + "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation", + "Unpack", "DeviceIndex", "TensorStridedSliceUpdate", }}}; // clang-format on return result; @@ -2061,11 +2061,13 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaSelfAdjointEig", "XlaSend", "XlaSetBound", + "XlaSetDynamicDimensionSize", "XlaSharding", "XlaSort", "XlaSpmdFullToShardShape", "XlaSpmdShardToFullShape", "XlaSvd", + "XlaVariadicReduce", "XlaWhile", "Zeta", "_Arg", diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h index c652e5fe216..3931ae6c7cc 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ #define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index e2a1d159336..bf6dd5ab9f4 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index d7d5ee02265..435e3752b2e 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -283,25 +283,23 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool has_tensor_list_arg = - absl::c_any_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kTensorList; - }); const ConfigProto* config = ctx->function_library()->config_proto(); - bool use_mlir = config && config->experimental().enable_mlir_bridge(); + // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR. + bool use_mlir = config && config->experimental().enable_mlir_bridge() && + node_def.op() != "VarIsInitializedOp"; #ifdef LIBTPU_ON_GCE - if (use_mlir && has_tensor_list_arg) { + if (use_mlir) { LOG(WARNING) << "MLIR is not supported in this environment."; } return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); #else - // TODO(b/155596779): Support TensorList args. - if (!use_mlir || !has_tensor_list_arg) { + if (!use_mlir) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } + VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; std::vector control_rets; if (result_dtypes.empty()) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc index 5578925b790..e40d6221324 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -78,7 +78,9 @@ TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { absl::StrContains(status.error_message(), "XLA compilation disabled")); } -static void BM_BuildSignature(int iters, int n_args) { +void BM_BuildSignature(::testing::benchmark::State& state) { + const int n_args = state.range(0); + NameAttrList fn; fn.set_name("afunction"); for (int i = 0; i < n_args; i++) { @@ -93,7 +95,7 @@ static void BM_BuildSignature(int iters, int n_args) { args[i].constant_value = Tensor(DT_INT32, {4, 0}); } - while (--iters > 0) { + for (auto i : state) { xla::StatusOr s = XlaCompilationCache::BuildSignature(fn, args); CHECK(s.ok()); diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index d4a69da4898..b90f8b7b990 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -89,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, XlaOpRegistry::RegisterCompilationKernels(); // Only check for compilability if the MLIR bridge is not enabled. - if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { std::vector diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a0e60b1eafe..1c5581eb4ab 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -426,7 +426,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( ShapedBuffer buffer( xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), - output.platform(), output.device_ordinal()); + output.device_ordinal()); buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), /*source_base_index=*/{}, /*target_base_index=*/{0}); @@ -583,7 +583,11 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( XlaCompiler::Argument& arg = out[input_num]; if (absl::c_binary_search(must_be_constant_idxs, input_num)) { // Handles compile-time constants. - TF_RET_CHECK(input->dtype() != DT_RESOURCE); + + // TODO(b/157241314): Support constants located in resource variables. + TF_RET_CHECK(input->dtype() != DT_RESOURCE) + << "tf2xla bridge does not support must-be-constants located in " + "resource variables; try moving them to a tensor"; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input->dtype(); arg.shape = input->shape(); diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 30e8d8a86a7..e1b81133724 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -45,6 +45,7 @@ filegroup( "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", @@ -122,8 +123,6 @@ gentbl( tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", @@ -150,6 +149,24 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], + td_srcs = [":hlo_ops_td_files"], +) + +gentbl( + name = "hlo_ops_base_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + td_relative_includes = [ + "include", + ], td_srcs = [":hlo_ops_td_files"], ) @@ -194,6 +211,63 @@ gentbl( td_srcs = [":hlo_ops_td_files"], ) +gentbl( + name = "lhlo_gpu_ops_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + ], +) + +cc_library( + name = "lhlo_gpu_ops_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h", + ], + includes = ["include"], + deps = [ + ":lhlo_gpu_ops_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +gentbl( + name = "lhlo_gpu_ops_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc"), + ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", + ], +) + #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl( name = "canonicalize_inc_gen", @@ -251,6 +325,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "hlo_ops_base_structs", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc", + "lib/Dialect/mhlo/IR/hlo_ops_base_structs.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h", + ], + includes = ["include"], + deps = [ + ":hlo_ops_base_structs_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "convert_op_folder", srcs = ["lib/utils/convert_op_folder.cc"], @@ -284,6 +375,7 @@ cc_library( ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":hlo_ops_inc_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", @@ -314,6 +406,7 @@ cc_library( includes = ["include"], deps = [ ":hlo_ops_base_inc_gen", + ":hlo_ops_base_structs", ":lhlo_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -330,6 +423,39 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lhlo_gpu", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h", + ], + includes = ["include"], + deps = [ + ":hlo", + ":hlo_ops_base_structs", + ":infer_fusibility_op_interface", + ":lhlo_gpu_ops_inc_gen", + ":lhlo_gpu_ops_structs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:CopyOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", + ], + alwayslink = 1, +) + cc_library( name = "hlo_dialect_registration", srcs = ["lib/Dialect/mhlo/IR/init.cc"], @@ -337,6 +463,7 @@ cc_library( deps = [ ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//mlir:IR", ], ) @@ -385,10 +512,20 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps", ], ) +cc_library( + name = "map_chlo_to_hlo_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"], + deps = [ + ":hlo", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "map_hlo_to_lhlo_op", hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"], @@ -410,6 +547,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -477,9 +615,11 @@ cc_library( ], deps = [ ":hlo", + ":map_chlo_to_hlo_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", @@ -522,6 +662,7 @@ cc_library( "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:ViewLikeInterface", ], alwayslink = 1, ) @@ -635,6 +776,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -762,6 +904,7 @@ cc_library( deps = [ ":chlo_legalize_to_hlo_inc_gen", ":hlo", + ":map_chlo_to_hlo_op", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", @@ -878,6 +1021,7 @@ cc_binary( ":all_passes", ":hlo", ":lhlo", + ":lhlo_gpu", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/hlo/README.md b/tensorflow/compiler/mlir/hlo/README.md index 61517cd9fca..05aabe3f67e 100644 --- a/tensorflow/compiler/mlir/hlo/README.md +++ b/tensorflow/compiler/mlir/hlo/README.md @@ -22,7 +22,7 @@ upstream. ## QuickStart: building and testing -These instructions work on Linux, you may have to adjust for your plaform. +These instructions work on Linux, you may have to adjust for your platform. To build the code in this repository, you need a clone of the LLVM/MLIR git repository: diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 09bdca84cd3..3fa2b908d9c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -25,7 +25,22 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) endfunction() add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(hlo_ops mhlo) add_mlir_hlo_dialect(lhlo_ops lmhlo) +set(LLVM_TARGET_DEFINITIONS hlo_ops.td) +mlir_tablegen(hlo_ops.h.inc -gen-op-decls) +mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) +mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRhlo_opsIncGen) + +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) +mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) +mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) +set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) +mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) +mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) +add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) + add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 13d5f02368b..a65d8258a51 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -89,10 +89,9 @@ class HLOClient_BroadcastBinaryElementwiseOp< OptionalAttr:$broadcast_dimensions ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; + let builders = [ + OpBuilderDAG<(ins "Value":$left, "Value":$right, + "DenseIntElementsAttr":$broadcast_dimensions)>]; let results = (outs HLO_Tensor); @@ -427,7 +426,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< string summary = "Compare operator (with optional broadcasting)"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. @@ -437,14 +439,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< HLO_Tensor:$lhs, HLO_Tensor:$rhs, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" - >]; + let builders = [ + OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs, + "DenseIntElementsAttr":$broadcast_dimensions, + "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>]; } #endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 60ee4e613eb..b354189c12a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" @@ -32,7 +33,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // clang-format off -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" // clang-format on diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 507f7c11d63..42db595634c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -25,11 +25,6 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" -def HLO_Dialect : Dialect { - let name = "mhlo"; - let cppNamespace = "::mlir::mhlo"; -} - class HLO_Op traits> : Op { // Whether this operation has a custom conversion to HLO or not. @@ -63,9 +58,8 @@ def HLO_ConstOp : HLO_Op<"constant", HLO_StaticShapeTensor:$output ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Attribute value" - >]; + let builders = [ + OpBuilderDAG<(ins "Attribute":$value)>]; let assemblyFormat = "attr-dict $value"; @@ -136,8 +130,8 @@ class HLO_UnaryElementwiseOp traits, } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputOutputShapeEquality(int input, int output) { return true; @@ -152,9 +146,8 @@ class HLO_UnaryElementwiseOp traits, def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultShape], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand" - >]; + let builders = [ + OpBuilderDAG<(ins "Value":$operand)>]; } def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", @@ -167,10 +160,8 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< "convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_ConvertOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value operand, " - "Type result_element_ty" - >]; + let builders = [ + OpBuilderDAG<(ins "Value":$operand, "Type":$result_element_ty)>]; let hasFolder = 1; @@ -247,7 +238,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real", } def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp { + let hasFolder = 1; +} def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, @@ -293,8 +286,8 @@ class HLO_BinaryElementwiseOp traits> : } LogicalResult reifyReturnTypeShapes( OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); } bool inferInputsShapeEquality(int lhs, int rhs) { return true; @@ -458,7 +451,7 @@ def HLO_SendOp : HLO_Op<"send", []> { let arguments = (ins HLO_TensorOrTuple:$operand, HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -483,7 +476,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { let arguments = (ins HLO_Token:$token, - ChannelHandle:$channel_id, + ChannelHandle:$channel_id, DefaultValuedAttr:$is_host_transfer ); @@ -587,7 +580,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$replica_groups, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); @@ -622,10 +615,9 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ let results = (outs Variadic); - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, ValueRange operands, " - "ValueRange init_values, DenseIntElementsAttr dimensions" - >]; + let builders = [ + OpBuilderDAG<(ins "ValueRange":$operands, "ValueRange":$init_values, + "DenseIntElementsAttr":$dimensions)>]; let extraClassDeclaration = [{ bool isFusibleWithConsumer() { @@ -661,18 +653,16 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &results, " - "Value value, int32_t index">]; + let builders = [ + OpBuilderDAG<(ins "Value":$value, "int32_t":$index)>]; } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { let arguments = (ins Variadic:$val); let results = (outs HLO_Tuple); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &results, " - "ValueRange values">]; + let builders = [ + OpBuilderDAG<(ins "ValueRange":$values)>]; let hasCanonicalizer = 1; } @@ -684,16 +674,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "StringAttr comparison_direction" - >]; + let builders = [ + OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs, + "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>, + ]; + + let hasCustomHLOConverter = 1; } //===----------------------------------------------------------------------===// @@ -703,7 +696,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, def HLO_SliceOp: HLO_Op< "slice", [NoSideEffect, SameOperandsAndResultElementType, - AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { + AllTypesMatch<["start_indices", "limit_indices", "strides"]>, + DeclareOpInterfaceMethods]> { let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$start_indices, @@ -715,21 +709,6 @@ def HLO_SliceOp: HLO_Op< let hasCanonicalizer = 1; let hasFolder = 1; - - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand, " - "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " - "DenseIntElementsAttr strides" - >]; - - let extraClassDeclaration = [{ - // Infers output type for given operand and attributes. Result type is - // unranked if any of the attributes is illegal. - static Type InferOutputTypes(Builder *builder, Value operand, - DenseIntElementsAttr start_indices, - DenseIntElementsAttr limit_indices, - DenseIntElementsAttr strides); - }]; } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", @@ -959,15 +938,6 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, - StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, - StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> - ]> { - let description = "Structure of dimension information for dot product"; -} - def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { let arguments = (ins HLO_Tensor:$lhs, @@ -1029,14 +999,6 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let results = (outs HLO_Tensor); } -def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, - [StructFieldAttr<"offset_dims", I64ElementsAttr>, - StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, - StructFieldAttr<"start_index_map", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for gather"; -} - def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = (ins HLO_Tensor:$operand, @@ -1114,7 +1076,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, HLO_Tensor:$updates, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -1124,6 +1086,8 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, let results = (outs HLO_Tensor); let hasCustomHLOConverter = 1; + + let hasFolder = 1; } // TODO(jpienaar): Add broadcastable trait. @@ -1181,10 +1145,9 @@ def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShap let regions = (region SizedRegion<1>:$comparator); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &state, ValueRange operands, " - "int64_t dimension = -1, bool is_stable = false" - >]; + let builders = [ + OpBuilderDAG<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension, + CArg<"bool", "false">:$is_stable)>]; // TODO(b/129422361): SortOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -1220,6 +1183,8 @@ def HLO_PadOp: HLO_Op<"pad", // TODO(b/129422361): PadOp has a custom constructor for HLO. let hasCustomHLOConverter = 1; + + let hasFolder = 1; } def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index cba2dc370f0..572a2f9dc07 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -18,6 +18,13 @@ limitations under the License. include "mlir/IR/OpBase.td" +def HLO_Dialect : Dialect { + let name = "mhlo"; + let cppNamespace = "::mlir::mhlo"; +} + +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" + def HLO_Pred : TypeAlias; // TODO(hinsu): Use signed integers instead of signless integer which is being @@ -614,15 +621,6 @@ class BASE_HLO_CaseOp { // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// -// Represents a unique identifier for each Send/Recv instruction pair or -// optionally for collective instructions (AllReduce, CollectivePermute, -// AllToAll). Non-positive channel_id handle is equivalent to no channel id. -class ChannelHandle : StructAttr<"ChannelHandle", dialect, [ - StructFieldAttr<"handle", I64Attr>, - StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; -} - class BASE_HLO_ReplicaIdOp { string summary = "ReplicaId operator"; @@ -712,6 +710,7 @@ def HLO_PrecisionConfigAttr: OptionalAttr< TypedArrayAttrBase>; + //===----------------------------------------------------------------------===// // Fast Fourier Transform Type enum definitions. //===----------------------------------------------------------------------===// @@ -750,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", HLO_COMPARISON_DIRECTION_LT ]>; +def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">; +def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">; +def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">; +def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">; +def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">; + +def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType", + "Which comparison type to use.", + [ + HLO_COMPARISON_TYPE_FLOAT, + HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + HLO_COMPARISON_TYPE_SIGNED, + HLO_COMPARISON_TYPE_UNSIGNED + ]>; + + class BASE_HLO_CompareOp { string summary = "Comparison operator"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. @@ -1011,21 +1029,6 @@ class BASE_HLO_ConcatenateOp { // Common convolution attributes //===----------------------------------------------------------------------===// -class ConvDimensionNumbersBase - : StructAttr<"ConvDimensionNumbers", dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - class ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. @@ -1036,7 +1039,7 @@ class ConvolutionAttributes { OptionalAttr:$lhs_dilation, // Default value: one for each of the spatial dimension. OptionalAttr:$rhs_dilation, - ConvDimensionNumbersBase:$dimension_numbers, + ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, HLO_PrecisionConfigAttr:$precision_config @@ -1164,15 +1167,6 @@ class BASE_HLO_ReshapeOp { }]; } -class ScatterDimensionNumbers : StructAttr< - "ScatterDimensionNumbers", dialect, [ - StructFieldAttr<"update_window_dims", I64ElementsAttr>, - StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, - StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, - StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for scatter"; -} - class BASE_HLO_ScatterOp { string summary = "Scatter operator"; @@ -1388,7 +1382,7 @@ class BASE_HLO_BitcastOp { string description = [{ This op changes the shape of the input in the way that the physical - arranggment of elements are unchanged. + arrangement of elements are unchanged. However, the op needs layout information to make sense of "physical arrangement of elements". Layout support in MHLO is currently under diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h new file mode 100644 index 00000000000..3b78ff8a367 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in MHLO and LMHLO. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td new file mode 100644 index 00000000000..d25eb5104c6 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE_STRUCTS +#define HLO_OPS_BASE_STRUCTS + +//===----------------------------------------------------------------------===// +// Dot dimensions enum definitions. +//===----------------------------------------------------------------------===// + +def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>, + StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, + StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> + ]> { + let description = "Structure of dimension information for dot product"; +} + +def ScatterDimensionNumbers : StructAttr< + "ScatterDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"update_window_dims", I64ElementsAttr>, + StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, + StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for scatter"; +} + +def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, + [StructFieldAttr<"offset_dims", I64ElementsAttr>, + StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, + StructFieldAttr<"start_index_map", I64ElementsAttr>, + StructFieldAttr<"index_vector_dim", I64Attr>]> { + let description = "Structure of dimension information for gather"; +} + + +// Represents a unique identifier for each Send/Recv instruction pair or +// optionally for collective instructions (AllReduce, CollectivePermute, +// AllToAll). Non-positive channel_id handle is equivalent to no channel id. +def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ + StructFieldAttr<"handle", I64Attr>, + StructFieldAttr<"type", I64Attr>]> { + let description = "two 64-bit integers 'handle' and 'type'"; +} + +#endif // HLO_OPS_BASE_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h new file mode 100644 index 00000000000..effa9ecc83b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LHLO dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +class OpBuilder; +} // namespace mlir + + +namespace mlir { +namespace lmhlo_gpu { + +class LmhloGpuDialect : public Dialect { + public: + explicit LmhloGpuDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo_gpu"; } +}; + +} // namespace lmhlo_gpu +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td new file mode 100644 index 00000000000..b3708bf4ff1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -0,0 +1,210 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LHMLO level GPU operations. +// Because these are LMHLO level operations, they operate on memrefs. + +#ifndef LHLO_GPU_OPS +#define LHLO_GPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" + + +class LHLOGPU_Op traits = []> : + Op], traits)>; + +// Type for scratch buffers used by GPU library calls (memref) +def UntypedBuffer : MemRefRankOf<[I8], [1]>; + +// Cholesky info output buffer type. +def I32Buffer : MemRefOf<[I32]>; + +//===----------------------------------------------------------------------===// +// LMHLO ops representing batch norm library functions. +//===----------------------------------------------------------------------===// + +// Note: these are semantically different from similar LHLO as the GPU library +// calls generate or consume standard deviation, whereas LHLO ops generate or +// consume variance (= std-dev ^ 2). + +def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, + BASE_HLO_BatchNormGradOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$stddev, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. + Arg:$grad_scale, + Arg:$grad_offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, + BASE_HLO_BatchNormInferenceOp { + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$stddev, + Arg:$output, + F32Attr:$epsilon, + I64Attr:$feature_index); +} + +def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, + BASE_HLO_BatchNormTrainingOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_stddev, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing convolution library functions. +//===----------------------------------------------------------------------===// + +def ActivationModeNone : StrEnumAttrCase<"None">; +def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; +def ActivationModeTanh : StrEnumAttrCase<"Relu">; +def ActivationModeRelu : StrEnumAttrCase<"Relu">; +def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; +def ActivationModeReluX : StrEnumAttrCase<"ReluX">; +def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; + +def ActivationAttr : StrEnumAttr<"Activation", + "Activation applied with fused convolution", + [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, + ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, + ActivationModeBandPass]>; + +def GpuConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def GpuFusedConvolutionAttributes { + dag attributes = !con( + ConvolutionAttributes.attributes, + (ins F64Attr:$result_scale, + ActivationAttr:$activation_mode, + F64Attr:$side_input_scale), + (ins ConvolutionBackendConfigAttr:$backend_config)); +} + +def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { + let arguments = !con( + (ins + Arg:$d_output, + Arg:$filter, + Arg:$d_input, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$d_output, + Arg:$d_filter, + Arg:$scratch), + GpuConvolutionAttributes.attributes); +} + +// output = activation(result_scale * conv(input, filter) + +// side_input * side_input_scale + +// bias) +def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$bias, + Arg:$side_input, + Arg:$output, + Arg:$scratch), + GpuFusedConvolutionAttributes.attributes); +} + +//===----------------------------------------------------------------------===// +// LMHLO ops representing other library functions. +//===----------------------------------------------------------------------===// + +// output = alpha * (lhs * rhs) +// Verify: beta = 0.0 +def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +// output = alpha(lhs * rhs) + beta * bias +def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$bias, + Arg:$output, + DotDimensionNumbers:$dot_dimension_numbers, + F64Attr:$alpha, + F64Attr:$beta, + I64Attr:$batch_size, + I64Attr:$algorithm); +} + +def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { + let arguments = (ins + Arg:$input, + Arg:$output, + Arg:$scratch, + Arg:$info, + BoolAttr:$is_upper); +} + +#endif // LHLO_GPU_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td new file mode 100644 index 00000000000..820e4ce64b0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// We define the dialect here so that both structs and ops can refer to it. + +#ifndef LHLO_GPU_OPS_BASE +#define LHLO_GPU_OPS_BASE + +include "mlir/IR/OpBase.td" + +def LHLO_GPU_Dialect : Dialect { + let name = "lmhlo_gpu"; + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +#endif // LHLO_GPU_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h new file mode 100644 index 00000000000..ff642b82c22 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ==============================================================================*/ + +// This file defines structures used in the LMHLO_GPU dialect. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td new file mode 100644 index 00000000000..2236fc38e29 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td @@ -0,0 +1,29 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_GPU_OPS_STRUCTS +#define LHLO_GPU_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" + +def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", + LHLO_GPU_Dialect, [ + StructFieldAttr<"algorithm", I64Attr>, + StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { + let description = "GPU Convolution backend configuration"; +} + +#endif // LHLO_GPU_OPS_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index cc24e17c001..9dc6d7aa0c0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines the operations used in the LXLA dialect. +// This file defines the operations used in the LHLO dialect. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" @@ -33,11 +34,6 @@ limitations under the License. namespace mlir { class OpBuilder; -} // namespace mlir - -#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" - -namespace mlir { namespace lmhlo { class LmhloDialect : public Dialect { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index c013939c544..32901f47dbe 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { Arg:$rhs, Arg:$out, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); } @@ -340,7 +341,8 @@ def HLO_StaticMemRefCastOp: Op:$operand); let results = (outs Res:$result); - let builders = [OpBuilder<"MemRefType resultType, Value operand", + let builders = [ + OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand), [{ $_state.addOperands(operand); $_state.types.push_back(resultType); @@ -386,8 +388,9 @@ def HLO_DynamicMemRefCastOp: Op:$result); let builders = [ - OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, " - "ValueRange strides", [{ + OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand, + "ValueRange":$sizes, "ValueRange":$strides), + [{ $_state.addOperands(operand); $_state.addOperands(sizes); $_state.addOperands(strides); @@ -592,6 +595,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { let arguments = (ins Arg:$lhs, Arg:$rhs, + DotDimensionNumbers:$dot_dimension_numbers, HLO_PrecisionConfigAttr:$precision_config, Arg:$output ); @@ -601,11 +605,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { let arguments = (ins Arg:$operand, Arg:$start_indices, - I64Attr:$index_vector_dim, - I64ElementsAttr:$offset_dims, + GatherDimensionNumbers:$dimension_numbers, I64ElementsAttr:$slice_sizes, - I64ElementsAttr:$collapsed_slice_dims, - I64ElementsAttr:$start_index_map, Arg:$output ); } @@ -623,7 +624,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { Arg:$scatter_indices, Arg:$updates, Arg:$output, - ScatterDimensionNumbers:$scatter_dimension_numbers, + ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedAttr:$indices_are_sorted, DefaultValuedAttr:$unique_indices ); @@ -699,7 +700,7 @@ def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, Arg:$output, I64ElementsAttr:$replica_groups, DefaultValuedAttr:$constrain_layout, - OptionalAttr>:$channel_id, + OptionalAttr:$channel_id, DefaultValuedAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); @@ -712,7 +713,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, Arg:$operand, Arg:$output, I64ElementsAttr:$source_target_pairs, - OptionalAttr>:$channel_id + OptionalAttr:$channel_id ); } @@ -814,7 +815,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"ArrayRef attributes"> + OpBuilderDAG<(ins "ArrayRef":$attributes)> ]; } @@ -824,9 +825,9 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. }]; - let builders = [OpBuilder<"ValueRange operands", - [{ build($_builder, $_state, llvm::None, operands, llvm::None); }] - >]; + let builders = [ + OpBuilderDAG<(ins "ValueRange":$operands), + [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>]; } #endif // LHLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h new file mode 100644 index 00000000000..316e65076ae --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ + +#include + +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace chlo { + +struct HloComplexAdaptor { + static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +template +struct HloBinaryElementwiseAdaptor { + static ToOpTy CreateOp(FromOpTy from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +struct HloCompareAdaptor { + static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create( + from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction(), from_op.compare_typeAttr()); + } +}; + +// Populate a pattern for each Broadcasting CHlo op. This requires the pattern +// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values. +template